1const Decompress = @This();
2const std = @import("std");
3const assert = std.debug.assert;
4const Reader = std.Io.Reader;
5const Limit = std.Io.Limit;
6const zstd = @import("../zstd.zig");
7const Writer = std.Io.Writer;
8
9input: *Reader,
10reader: Reader,
11state: State,
12verify_checksum: bool,
13window_len: u32,
14err: ?Error = null,
15
16const State = union(enum) {
17 new_frame,
18 in_frame: InFrame,
19 skipping_frame: usize,
20
21 const InFrame = struct {
22 frame: Frame,
23 checksum: ?u32,
24 decompressed_size: usize,
25 decode: Frame.Zstandard.Decode,
26 };
27};
28
29pub const Options = struct {
30 /// Verifying checksums is not implemented yet and will cause a panic if
31 /// you set this to true.
32 verify_checksum: bool = false,
33
34 /// The output buffer is asserted to have capacity for `window_len` plus
35 /// `zstd.block_size_max`.
36 ///
37 /// If `window_len` is too small, then some streams will fail to decompress
38 /// with `error.OutputBufferUndersize`.
39 window_len: u32 = zstd.default_window_len,
40};
41
42pub const Error = error{
43 BadMagic,
44 BlockOversize,
45 ChecksumFailure,
46 ContentOversize,
47 DictionaryIdFlagUnsupported,
48 EndOfStream,
49 HuffmanTreeIncomplete,
50 InvalidBitStream,
51 MalformedAccuracyLog,
52 MalformedBlock,
53 MalformedCompressedBlock,
54 MalformedFrame,
55 MalformedFseBits,
56 MalformedFseTable,
57 MalformedHuffmanTree,
58 MalformedLiteralsHeader,
59 MalformedLiteralsLength,
60 MalformedLiteralsSection,
61 MalformedSequence,
62 MissingStartBit,
63 OutputBufferUndersize,
64 InputBufferUndersize,
65 ReadFailed,
66 RepeatModeFirst,
67 ReservedBitSet,
68 ReservedBlock,
69 SequenceBufferUndersize,
70 TreelessLiteralsFirst,
71 UnexpectedEndOfLiteralStream,
72 WindowOversize,
73 WindowSizeUnknown,
74};
75
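/// Used by `init` when `buffer` is empty: decompressed data is streamed directly
/// into the destination `Writer`.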
76const direct_vtable: Reader.VTable = .{
77 .stream = streamDirect,
78 .rebase = rebaseFallible,
79 .discard = discardDirect,
80 .readVec = readVec,
81};
82
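/// Used by `init` when `buffer` is non-empty: decompressed data is first written
/// into `reader.buffer`, which also serves as the match window.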
83const indirect_vtable: Reader.VTable = .{
84 .stream = streamIndirect,
85 .rebase = rebaseFallible,
86 .discard = discardIndirect,
87 .readVec = readVec,
88};
89
/// When connecting `reader` to a `Writer`, `buffer` should be empty, and the
/// `Writer`'s buffer must have capacity for `Options.window_len` plus
/// `zstd.block_size_max`.
///
/// Otherwise, `buffer` itself must meet that capacity requirement.
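///
/// A minimal usage sketch (illustrative only; `input` is assumed to be some
/// `*std.Io.Reader` and `out` some `*std.Io.Writer`):
///
///     var window: [zstd.default_window_len + zstd.block_size_max]u8 = undefined;
///     var decompress: Decompress = .init(input, &window, .{});
///     // `decompress.reader` is an ordinary `std.Io.Reader`; pull decompressed
///     // bytes from it until it reports `error.EndOfStream`.
///     while (true) {
///         _ = decompress.reader.stream(out, .limited(window.len)) catch |err| switch (err) {
///             error.EndOfStream => break,
///             else => |e| return e,
///         };
///     }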
94pub fn init(input: *Reader, buffer: []u8, options: Options) Decompress {
95 if (buffer.len != 0) assert(buffer.len >= options.window_len + zstd.block_size_max);
96 return .{
97 .input = input,
98 .state = .new_frame,
99 .verify_checksum = options.verify_checksum,
100 .window_len = options.window_len,
101 .reader = .{
102 .vtable = if (buffer.len == 0) &direct_vtable else &indirect_vtable,
103 .buffer = buffer,
104 .seek = 0,
105 .end = 0,
106 },
107 };
108}
109
110fn streamDirect(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize {
111 const d: *Decompress = @alignCast(@fieldParentPtr("reader", r));
112 return stream(d, w, limit);
113}
114
115fn streamIndirect(r: *Reader, w: *Writer, limit: std.Io.Limit) Reader.StreamError!usize {
116 const d: *Decompress = @alignCast(@fieldParentPtr("reader", r));
117 _ = limit;
118 _ = w;
119 return streamIndirectInner(d);
120}
121
122fn rebaseFallible(r: *Reader, capacity: usize) Reader.RebaseError!void {
123 rebase(r, capacity);
124}
125
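/// Discards consumed bytes from the front of `buffer` while keeping at least
/// `window_len` bytes of already-decompressed history, so that later match
/// back-references can still be resolved.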
126fn rebase(r: *Reader, capacity: usize) void {
127 const d: *Decompress = @alignCast(@fieldParentPtr("reader", r));
128 assert(capacity <= r.buffer.len - d.window_len);
129 assert(r.end + capacity > r.buffer.len);
130 const discard_n = @min(r.seek, r.end - d.window_len);
131 const keep = r.buffer[discard_n..r.end];
132 @memmove(r.buffer[0..keep.len], keep);
133 r.end = keep.len;
134 r.seek -= discard_n;
135}
136
/// This could be improved: when the discarded amount covers an entire frame,
/// decoding of that frame could be skipped.
139fn discardDirect(r: *Reader, limit: std.Io.Limit) Reader.Error!usize {
140 const d: *Decompress = @alignCast(@fieldParentPtr("reader", r));
141 rebase(r, d.window_len);
142 var writer: Writer = .{
143 .vtable = &.{
144 .drain = std.Io.Writer.Discarding.drain,
145 .sendFile = std.Io.Writer.Discarding.sendFile,
146 },
147 .buffer = r.buffer,
148 .end = r.end,
149 };
150 defer {
151 r.end = writer.end;
152 r.seek = r.end;
153 }
154 const n = r.stream(&writer, limit) catch |err| switch (err) {
155 error.WriteFailed => unreachable,
156 error.ReadFailed => return error.ReadFailed,
157 error.EndOfStream => return error.EndOfStream,
158 };
159 assert(n <= @intFromEnum(limit));
160 return n;
161}
162
163fn discardIndirect(r: *Reader, limit: std.Io.Limit) Reader.Error!usize {
164 const d: *Decompress = @alignCast(@fieldParentPtr("reader", r));
165 rebase(r, d.window_len);
166 var writer: Writer = .{
167 .buffer = r.buffer,
168 .end = r.end,
169 .vtable = &.{ .drain = Writer.unreachableDrain },
170 };
171 {
172 defer r.end = writer.end;
173 _ = stream(d, &writer, .limited(writer.buffer.len - writer.end)) catch |err| switch (err) {
174 error.WriteFailed => unreachable,
175 else => |e| return e,
176 };
177 }
178 const n = limit.minInt(r.end - r.seek);
179 r.seek += n;
180 return n;
181}
182
183fn readVec(r: *Reader, data: [][]u8) Reader.Error!usize {
184 _ = data;
185 const d: *Decompress = @alignCast(@fieldParentPtr("reader", r));
186 return streamIndirectInner(d);
187}
188
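/// Decompresses into the reader's own buffer (advancing `reader.end`) rather than
/// into a caller-provided destination, so the directly streamed byte count is 0 and
/// callers consume the newly buffered data through the usual `Reader` machinery.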
189fn streamIndirectInner(d: *Decompress) Reader.Error!usize {
190 const r = &d.reader;
191 if (r.buffer.len - r.end < zstd.block_size_max) rebase(r, zstd.block_size_max);
192 assert(r.buffer.len - r.end >= zstd.block_size_max);
193 var writer: Writer = .{
194 .buffer = r.buffer,
195 .end = r.end,
196 .vtable = &.{
197 .drain = Writer.unreachableDrain,
198 .rebase = Writer.unreachableRebase,
199 },
200 };
201 defer r.end = writer.end;
202 _ = stream(d, &writer, .limited(writer.buffer.len - writer.end)) catch |err| switch (err) {
203 error.WriteFailed => unreachable,
204 else => |e| return e,
205 };
206 return 0;
207}
208
209fn stream(d: *Decompress, w: *Writer, limit: Limit) Reader.StreamError!usize {
210 const in = d.input;
211
212 state: switch (d.state) {
213 .new_frame => {
            // Only return EndOfStream when exactly 0 bytes of the frame magic have
            // been read; a partially read magic number is treated as a failure.
216 in.fill(@sizeOf(Frame.Magic)) catch |err| switch (err) {
217 error.EndOfStream => {
218 if (in.bufferedLen() != 0) {
219 d.err = error.BadMagic;
220 return error.ReadFailed;
221 }
222 return err;
223 },
224 else => |e| return e,
225 };
226 const magic = try in.takeEnumNonexhaustive(Frame.Magic, .little);
227 initFrame(d, magic) catch |err| {
228 d.err = err;
229 return error.ReadFailed;
230 };
231 continue :state d.state;
232 },
233 .in_frame => |*in_frame| {
234 return readInFrame(d, w, limit, in_frame) catch |err| switch (err) {
235 error.ReadFailed => return error.ReadFailed,
236 error.WriteFailed => return error.WriteFailed,
237 else => |e| {
238 d.err = e;
239 return error.ReadFailed;
240 },
241 };
242 },
243 .skipping_frame => |*remaining| {
244 const n = in.discard(.limited(remaining.*)) catch |err| {
245 d.err = err;
246 return error.ReadFailed;
247 };
248 remaining.* -= n;
249 if (remaining.* == 0) d.state = .new_frame;
250 return 0;
251 },
252 }
253}
254
255fn initFrame(d: *Decompress, magic: Frame.Magic) !void {
256 const in = d.input;
257 switch (magic.kind() orelse return error.BadMagic) {
258 .zstandard => {
259 const header = try Frame.Zstandard.Header.decode(in);
260 d.state = .{ .in_frame = .{
261 .frame = try Frame.init(header, d.window_len, d.verify_checksum),
262 .checksum = null,
263 .decompressed_size = 0,
264 .decode = .init,
265 } };
266 },
267 .skippable => {
268 const frame_size = try in.takeInt(u32, .little);
269 d.state = .{ .skipping_frame = frame_size };
270 },
271 }
272}
273
274fn readInFrame(d: *Decompress, w: *Writer, limit: Limit, state: *State.InFrame) !usize {
275 const in = d.input;
276 const window_len = d.window_len;
277
278 const block_header = try in.takeStruct(Frame.Zstandard.Block.Header, .little);
279 const block_size = block_header.size;
280 const frame_block_size_max = state.frame.block_size_max;
281 if (frame_block_size_max < block_size) return error.BlockOversize;
282 if (@intFromEnum(limit) < block_size) return error.OutputBufferUndersize;
283 var bytes_written: usize = 0;
284 switch (block_header.type) {
285 .raw => {
286 try in.streamExactPreserve(w, window_len, block_size);
287 bytes_written = block_size;
288 },
289 .rle => {
290 const byte = try in.takeByte();
291 try w.splatBytePreserve(window_len, byte, block_size);
292 bytes_written = block_size;
293 },
294 .compressed => {
295 var literals_buffer: [zstd.block_size_max]u8 = undefined;
296 var sequence_buffer: [zstd.block_size_max]u8 = undefined;
297 var remaining: Limit = .limited(block_size);
298 const literals = try LiteralsSection.decode(in, &remaining, &literals_buffer);
299 const sequences_header = try SequencesSection.Header.decode(in, &remaining);
300
301 const decode = &state.decode;
302 try decode.prepare(in, &remaining, literals, sequences_header);
303
304 {
305 if (sequence_buffer.len < @intFromEnum(remaining))
306 return error.SequenceBufferUndersize;
307 const seq_slice = remaining.slice(&sequence_buffer);
308 try in.readSliceAll(seq_slice);
309 var bit_stream = try ReverseBitReader.init(seq_slice);
310
311 if (sequences_header.sequence_count > 0) {
312 try decode.readInitialFseState(&bit_stream);
313
314 // Ensures the following calls to `decodeSequence` will not flush.
315 const dest = (try w.writableSliceGreedyPreserve(window_len, frame_block_size_max))[0..frame_block_size_max];
316 const write_pos = dest.ptr - w.buffer.ptr;
317 for (0..sequences_header.sequence_count - 1) |_| {
318 bytes_written += try decode.decodeSequence(w.buffer, write_pos + bytes_written, &bit_stream);
319 try decode.updateState(.literal, &bit_stream);
320 try decode.updateState(.match, &bit_stream);
321 try decode.updateState(.offset, &bit_stream);
322 }
323 bytes_written += try decode.decodeSequence(w.buffer, write_pos + bytes_written, &bit_stream);
324 if (bytes_written > dest.len) return error.MalformedSequence;
325 w.advance(bytes_written);
326 }
327
328 if (!bit_stream.isEmpty()) {
329 return error.MalformedCompressedBlock;
330 }
331 }
332
333 if (decode.literal_written_count < literals.header.regenerated_size) {
334 const len = literals.header.regenerated_size - decode.literal_written_count;
335 try decode.decodeLiterals(w, len);
336 decode.literal_written_count += len;
337 bytes_written += len;
338 }
339
340 switch (decode.literal_header.block_type) {
341 .treeless, .compressed => {
342 if (!decode.isLiteralStreamEmpty()) return error.MalformedCompressedBlock;
343 },
344 .raw, .rle => {},
345 }
346
347 if (bytes_written > frame_block_size_max) return error.BlockOversize;
348 },
349 .reserved => return error.ReservedBlock,
350 }
351
352 if (state.frame.hasher_opt) |*hasher| {
353 if (bytes_written > 0) {
354 _ = hasher;
            @panic("TODO: bytes written need to go through the hasher too");
356 }
357 }
358
359 state.decompressed_size += bytes_written;
360
361 if (block_header.last) {
362 if (state.frame.has_checksum) {
363 const expected_checksum = try in.takeInt(u32, .little);
364 if (state.frame.hasher_opt) |*hasher| {
365 const actual_checksum: u32 = @truncate(hasher.final());
366 if (expected_checksum != actual_checksum) return error.ChecksumFailure;
367 }
368 }
369 if (state.frame.content_size) |content_size| {
370 if (content_size != state.decompressed_size) {
371 return error.MalformedFrame;
372 }
373 }
374 d.state = .new_frame;
375 } else if (state.frame.content_size) |content_size| {
376 if (state.decompressed_size > content_size) return error.MalformedFrame;
377 }
378
379 return bytes_written;
380}
381
382pub const Frame = struct {
383 hasher_opt: ?std.hash.XxHash64,
384 window_size: usize,
385 has_checksum: bool,
386 block_size_max: usize,
387 content_size: ?usize,
388
389 pub const Magic = enum(u32) {
390 zstandard = 0xFD2FB528,
391 _,
392
393 pub fn kind(m: Magic) ?Kind {
394 return switch (@intFromEnum(m)) {
395 @intFromEnum(Magic.zstandard) => .zstandard,
396 @intFromEnum(Skippable.magic_min)...@intFromEnum(Skippable.magic_max) => .skippable,
397 else => null,
398 };
399 }
400
401 pub fn isSkippable(m: Magic) bool {
402 return switch (@intFromEnum(m)) {
403 @intFromEnum(Skippable.magic_min)...@intFromEnum(Skippable.magic_max) => true,
404 else => false,
405 };
406 }
407 };
408
409 pub const Kind = enum { zstandard, skippable };
410
411 pub const Zstandard = struct {
412 pub const magic: Magic = .zstandard;
413
414 header: Header,
415 data_blocks: []Block,
416 checksum: ?u32,
417
418 pub const Header = struct {
419 descriptor: Descriptor,
420 window_descriptor: ?u8,
421 dictionary_id: ?u32,
422 content_size: ?u64,
423
424 pub const Descriptor = packed struct {
425 dictionary_id_flag: u2,
426 content_checksum_flag: bool,
427 reserved: bool,
428 unused: bool,
429 single_segment_flag: bool,
430 content_size_flag: u2,
431 };
432
433 pub const DecodeError = Reader.Error || error{ReservedBitSet};
434
435 pub fn decode(in: *Reader) DecodeError!Header {
436 const descriptor: Descriptor = @bitCast(try in.takeByte());
437
438 if (descriptor.reserved) return error.ReservedBitSet;
439
440 const window_descriptor: ?u8 = if (descriptor.single_segment_flag) null else try in.takeByte();
441
442 const dictionary_id: ?u32 = if (descriptor.dictionary_id_flag > 0) d: {
443 // if flag is 3 then field_size = 4, else field_size = flag
444 const field_size = (@as(u4, 1) << descriptor.dictionary_id_flag) >> 1;
445 break :d try in.takeVarInt(u32, .little, field_size);
446 } else null;
447
448 const content_size: ?u64 = if (descriptor.single_segment_flag or descriptor.content_size_flag > 0) c: {
449 const field_size = @as(u4, 1) << descriptor.content_size_flag;
450 const content_size = try in.takeVarInt(u64, .little, field_size);
451 break :c if (field_size == 2) content_size + 256 else content_size;
452 } else null;
453
454 return .{
455 .descriptor = descriptor,
456 .window_descriptor = window_descriptor,
457 .dictionary_id = dictionary_id,
458 .content_size = content_size,
459 };
460 }
461
462 /// Returns the window size required to decompress a frame, or `null` if it
463 /// cannot be determined (which indicates a malformed frame header).
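            ///
            /// Worked example of the computation below: a window descriptor of
            /// 0x01 has exponent 0 and mantissa 1, so `window_base` is
            /// 1 << 10 = 1024 and `window_add` is (1024 / 8) * 1 = 128, for a
            /// window size of 1152 bytes.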
464 pub fn windowSize(header: Header) ?u64 {
465 if (header.window_descriptor) |descriptor| {
466 const exponent = (descriptor & 0b11111000) >> 3;
467 const mantissa = descriptor & 0b00000111;
468 const window_log = 10 + exponent;
469 const window_base = @as(u64, 1) << @as(u6, @intCast(window_log));
470 const window_add = (window_base / 8) * mantissa;
471 return window_base + window_add;
472 } else return header.content_size;
473 }
474 };
475
476 pub const Block = struct {
477 pub const Header = packed struct(u24) {
478 last: bool,
479 type: Type,
480 size: u21,
481 };
482
483 pub const Type = enum(u2) {
484 raw,
485 rle,
486 compressed,
487 reserved,
488 };
489 };
490
491 pub const Decode = struct {
492 repeat_offsets: [3]u32,
493
494 offset: StateData(8),
495 match: StateData(9),
496 literal: StateData(9),
497
498 literal_fse_buffer: [zstd.table_size_max.literal]Table.Fse,
499 match_fse_buffer: [zstd.table_size_max.match]Table.Fse,
500 offset_fse_buffer: [zstd.table_size_max.offset]Table.Fse,
501
502 fse_tables_undefined: bool,
503
504 literal_stream_reader: ReverseBitReader,
505 literal_stream_index: usize,
506 literal_streams: LiteralsSection.Streams,
507 literal_header: LiteralsSection.Header,
508 huffman_tree: ?LiteralsSection.HuffmanTree,
509
510 literal_written_count: usize,
511
512 fn StateData(comptime max_accuracy_log: comptime_int) type {
513 return struct {
514 state: @This().State,
515 table: Table,
516 accuracy_log: u8,
517
518 const State = std.meta.Int(.unsigned, max_accuracy_log);
519 };
520 }
521
522 const init: Decode = .{
523 .repeat_offsets = .{
524 zstd.start_repeated_offset_1,
525 zstd.start_repeated_offset_2,
526 zstd.start_repeated_offset_3,
527 },
528
529 .offset = undefined,
530 .match = undefined,
531 .literal = undefined,
532
533 .literal_fse_buffer = undefined,
534 .match_fse_buffer = undefined,
535 .offset_fse_buffer = undefined,
536
537 .fse_tables_undefined = true,
538
539 .literal_written_count = 0,
540 .literal_header = undefined,
541 .literal_streams = undefined,
542 .literal_stream_reader = undefined,
543 .literal_stream_index = undefined,
544 .huffman_tree = null,
545 };
546
547 pub const PrepareError = error{
548 /// the (reversed) literal bitstream's first byte does not have any bits set
549 MissingStartBit,
550 /// `literals` is a treeless literals section and the decode state does not
551 /// have a Huffman tree from a previous block
552 TreelessLiteralsFirst,
553 /// on the first call if one of the sequence FSE tables is set to repeat mode
554 RepeatModeFirst,
555 /// an FSE table has an invalid accuracy
556 MalformedAccuracyLog,
557 /// failed decoding an FSE table
558 MalformedFseTable,
559 /// input stream ends before all FSE tables are read
560 EndOfStream,
561 ReadFailed,
562 InputBufferUndersize,
563 };
564
565 /// Prepare the decoder to decode a compressed block. Loads the
566 /// literals stream and Huffman tree from `literals` and reads the
567 /// FSE tables from `in`.
568 pub fn prepare(
569 self: *Decode,
570 in: *Reader,
571 remaining: *Limit,
572 literals: LiteralsSection,
573 sequences_header: SequencesSection.Header,
574 ) PrepareError!void {
575 self.literal_written_count = 0;
576 self.literal_header = literals.header;
577 self.literal_streams = literals.streams;
578
579 if (literals.huffman_tree) |tree| {
580 self.huffman_tree = tree;
581 } else if (literals.header.block_type == .treeless and self.huffman_tree == null) {
582 return error.TreelessLiteralsFirst;
583 }
584
585 switch (literals.header.block_type) {
586 .raw, .rle => {},
587 .compressed, .treeless => {
588 self.literal_stream_index = 0;
589 switch (literals.streams) {
590 .one => |slice| try self.initLiteralStream(slice),
591 .four => |streams| try self.initLiteralStream(streams[0]),
592 }
593 },
594 }
595
596 if (sequences_header.sequence_count > 0) {
597 try self.updateFseTable(in, remaining, .literal, sequences_header.literal_lengths);
598 try self.updateFseTable(in, remaining, .offset, sequences_header.offsets);
599 try self.updateFseTable(in, remaining, .match, sequences_header.match_lengths);
600 self.fse_tables_undefined = false;
601 }
602 }
603
604 /// Read initial FSE states for sequence decoding.
605 pub fn readInitialFseState(self: *Decode, bit_reader: *ReverseBitReader) error{EndOfStream}!void {
606 self.literal.state = try bit_reader.readBitsNoEof(u9, self.literal.accuracy_log);
607 self.offset.state = try bit_reader.readBitsNoEof(u8, self.offset.accuracy_log);
608 self.match.state = try bit_reader.readBitsNoEof(u9, self.match.accuracy_log);
609 }
610
611 fn updateRepeatOffset(self: *Decode, offset: u32) void {
612 self.repeat_offsets[2] = self.repeat_offsets[1];
613 self.repeat_offsets[1] = self.repeat_offsets[0];
614 self.repeat_offsets[0] = offset;
615 }
616
617 fn useRepeatOffset(self: *Decode, index: usize) u32 {
618 if (index == 1)
619 std.mem.swap(u32, &self.repeat_offsets[0], &self.repeat_offsets[1])
620 else if (index == 2) {
621 std.mem.swap(u32, &self.repeat_offsets[0], &self.repeat_offsets[2]);
622 std.mem.swap(u32, &self.repeat_offsets[1], &self.repeat_offsets[2]);
623 }
624 return self.repeat_offsets[0];
625 }
626
627 const WhichFse = enum { offset, match, literal };
628
629 /// TODO: don't use `@field`
630 fn updateState(
631 self: *Decode,
632 comptime choice: WhichFse,
633 bit_reader: *ReverseBitReader,
634 ) error{ MalformedFseBits, EndOfStream }!void {
635 switch (@field(self, @tagName(choice)).table) {
636 .rle => {},
637 .fse => |table| {
638 const data = table[@field(self, @tagName(choice)).state];
639 const T = @TypeOf(@field(self, @tagName(choice))).State;
640 const bits_summand = try bit_reader.readBitsNoEof(T, data.bits);
641 const next_state = std.math.cast(
642 @TypeOf(@field(self, @tagName(choice))).State,
643 data.baseline + bits_summand,
644 ) orelse return error.MalformedFseBits;
645 @field(self, @tagName(choice)).state = next_state;
646 },
647 }
648 }
649
650 const FseTableError = error{
651 MalformedFseTable,
652 MalformedAccuracyLog,
653 RepeatModeFirst,
654 EndOfStream,
655 };
656
657 /// TODO: don't use `@field`
658 fn updateFseTable(
659 self: *Decode,
660 in: *Reader,
661 remaining: *Limit,
662 comptime choice: WhichFse,
663 mode: SequencesSection.Header.Mode,
664 ) !void {
665 const field_name = @tagName(choice);
666 switch (mode) {
667 .predefined => {
668 @field(self, field_name).accuracy_log =
669 @field(zstd.default_accuracy_log, field_name);
670
671 @field(self, field_name).table =
672 @field(Table, "predefined_" ++ field_name);
673 },
674 .rle => {
675 @field(self, field_name).accuracy_log = 0;
676 remaining.* = remaining.subtract(1) orelse return error.EndOfStream;
677 @field(self, field_name).table = .{ .rle = try in.takeByte() };
678 },
679 .fse => {
680 const max_table_size = 2048;
681 const peek_len: usize = remaining.minInt(max_table_size);
682 if (in.buffer.len < peek_len) return error.InputBufferUndersize;
683 const limited_buffer = try in.peek(peek_len);
684 var bit_reader: BitReader = .{ .bytes = limited_buffer };
685 const table_size = try Table.decode(
686 &bit_reader,
687 @field(zstd.table_symbol_count_max, field_name),
688 @field(zstd.table_accuracy_log_max, field_name),
689 &@field(self, field_name ++ "_fse_buffer"),
690 );
691 @field(self, field_name).table = .{
692 .fse = (&@field(self, field_name ++ "_fse_buffer"))[0..table_size],
693 };
694 @field(self, field_name).accuracy_log = std.math.log2_int_ceil(usize, table_size);
695 in.toss(bit_reader.index);
696 remaining.* = remaining.subtract(bit_reader.index).?;
697 },
698 .repeat => if (self.fse_tables_undefined) return error.RepeatModeFirst,
699 }
700 }
701
702 const Sequence = struct {
703 literal_length: u32,
704 match_length: u32,
705 offset: u32,
706 };
707
708 fn nextSequence(
709 self: *Decode,
710 bit_reader: *ReverseBitReader,
711 ) error{ InvalidBitStream, EndOfStream }!Sequence {
712 const raw_code = self.getCode(.offset);
713 const offset_code = std.math.cast(u5, raw_code) orelse {
714 return error.InvalidBitStream;
715 };
716 const offset_value = (@as(u32, 1) << offset_code) + try bit_reader.readBitsNoEof(u32, offset_code);
717
718 const match_code = self.getCode(.match);
719 if (match_code >= zstd.match_length_code_table.len)
720 return error.InvalidBitStream;
721 const match = zstd.match_length_code_table[match_code];
722 const match_length = match[0] + try bit_reader.readBitsNoEof(u32, match[1]);
723
724 const literal_code = self.getCode(.literal);
725 if (literal_code >= zstd.literals_length_code_table.len)
726 return error.InvalidBitStream;
727 const literal = zstd.literals_length_code_table[literal_code];
728 const literal_length = literal[0] + try bit_reader.readBitsNoEof(u32, literal[1]);
729
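                // Offset values above 3 encode a new offset of `offset_value - 3`.
                // Values 1..3 select a repeat offset; when the literal length is zero
                // the repeat indices shift by one, and a value of 3 means "most recent
                // repeat offset minus one", which is what the branches below implement.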
730 const offset = if (offset_value > 3) offset: {
731 const offset = offset_value - 3;
732 self.updateRepeatOffset(offset);
733 break :offset offset;
734 } else offset: {
735 if (literal_length == 0) {
736 if (offset_value == 3) {
737 const offset = self.repeat_offsets[0] - 1;
738 self.updateRepeatOffset(offset);
739 break :offset offset;
740 }
741 break :offset self.useRepeatOffset(offset_value);
742 }
743 break :offset self.useRepeatOffset(offset_value - 1);
744 };
745
746 if (offset == 0) return error.InvalidBitStream;
747
748 return .{
749 .literal_length = literal_length,
750 .match_length = match_length,
751 .offset = offset,
752 };
753 }
754
            /// Decode one sequence from `bit_reader` into `dest` at `write_pos`.
            /// The caller is responsible for updating the FSE states between
            /// sequences. Assumes `prepare` has been called for the block before
            /// attempting to decode sequences.
758 fn decodeSequence(
759 decode: *Decode,
760 dest: []u8,
761 write_pos: usize,
762 bit_reader: *ReverseBitReader,
763 ) !usize {
764 const sequence = try decode.nextSequence(bit_reader);
765 const literal_length: usize = sequence.literal_length;
766 const match_length: usize = sequence.match_length;
767 const sequence_length = literal_length + match_length;
768
769 if (sequence_length > dest[write_pos..].len)
770 return error.MalformedSequence;
771
772 const copy_start = std.math.sub(usize, write_pos + sequence.literal_length, sequence.offset) catch
773 return error.MalformedSequence;
774
775 if (decode.literal_written_count + literal_length > decode.literal_header.regenerated_size)
776 return error.MalformedLiteralsLength;
777 var sub_bw: Writer = .fixed(dest[write_pos..]);
778 try decodeLiterals(decode, &sub_bw, literal_length);
779 decode.literal_written_count += literal_length;
780 // This is not a @memmove; it intentionally repeats patterns
781 // caused by iterating one byte at a time.
782 for (
783 dest[write_pos + literal_length ..][0..match_length],
784 dest[copy_start..][0..match_length],
785 ) |*d, s| d.* = s;
786 return sequence_length;
787 }
788
789 fn nextLiteralMultiStream(self: *Decode) error{MissingStartBit}!void {
790 self.literal_stream_index += 1;
791 try self.initLiteralStream(self.literal_streams.four[self.literal_stream_index]);
792 }
793
794 fn initLiteralStream(self: *Decode, bytes: []const u8) error{MissingStartBit}!void {
795 self.literal_stream_reader = try ReverseBitReader.init(bytes);
796 }
797
798 fn isLiteralStreamEmpty(self: *Decode) bool {
799 switch (self.literal_streams) {
800 .one => return self.literal_stream_reader.isEmpty(),
801 .four => return self.literal_stream_index == 3 and self.literal_stream_reader.isEmpty(),
802 }
803 }
804
805 const LiteralBitsError = error{
806 MissingStartBit,
807 UnexpectedEndOfLiteralStream,
808 };
809 fn readLiteralsBits(
810 self: *Decode,
811 bit_count_to_read: u16,
812 ) LiteralBitsError!u16 {
813 return self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read) catch bits: {
814 if (self.literal_streams == .four and self.literal_stream_index < 3) {
815 try self.nextLiteralMultiStream();
816 break :bits self.literal_stream_reader.readBitsNoEof(u16, bit_count_to_read) catch
817 return error.UnexpectedEndOfLiteralStream;
818 } else {
819 return error.UnexpectedEndOfLiteralStream;
820 }
821 };
822 }
823
824 /// Decode `len` bytes of literals into `w`.
825 fn decodeLiterals(d: *Decode, w: *Writer, len: usize) !void {
826 switch (d.literal_header.block_type) {
827 .raw => {
828 try w.writeAll(d.literal_streams.one[d.literal_written_count..][0..len]);
829 },
830 .rle => {
831 try w.splatByteAll(d.literal_streams.one[0], len);
832 },
833 .compressed, .treeless => {
834 const buf = try w.writableSlice(len);
835 const huffman_tree = d.huffman_tree.?;
836 const max_bit_count = huffman_tree.max_bit_count;
837 const starting_bit_count = LiteralsSection.HuffmanTree.weightToBitCount(
838 huffman_tree.nodes[huffman_tree.symbol_count_minus_one].weight,
839 max_bit_count,
840 );
841 var bits_read: u4 = 0;
842 var huffman_tree_index: usize = huffman_tree.symbol_count_minus_one;
843 var bit_count_to_read: u4 = starting_bit_count;
844 for (buf) |*out| {
845 var prefix: u16 = 0;
846 while (true) {
847 const new_bits = try d.readLiteralsBits(bit_count_to_read);
848 prefix <<= bit_count_to_read;
849 prefix |= new_bits;
850 bits_read += bit_count_to_read;
851 const result = try huffman_tree.query(huffman_tree_index, prefix);
852
853 switch (result) {
854 .symbol => |sym| {
855 out.* = sym;
856 bit_count_to_read = starting_bit_count;
857 bits_read = 0;
858 huffman_tree_index = huffman_tree.symbol_count_minus_one;
859 break;
860 },
861 .index => |index| {
862 huffman_tree_index = index;
863 const bit_count = LiteralsSection.HuffmanTree.weightToBitCount(
864 huffman_tree.nodes[index].weight,
865 max_bit_count,
866 );
867 bit_count_to_read = bit_count - bits_read;
868 },
869 }
870 }
871 }
872 },
873 }
874 }
875
876 /// TODO: don't use `@field`
877 fn getCode(self: *Decode, comptime choice: WhichFse) u32 {
878 return switch (@field(self, @tagName(choice)).table) {
879 .rle => |value| value,
880 .fse => |table| table[@field(self, @tagName(choice)).state].symbol,
881 };
882 }
883 };
884 };
885
886 pub const Skippable = struct {
887 pub const magic_min: Magic = @enumFromInt(0x184D2A50);
888 pub const magic_max: Magic = @enumFromInt(0x184D2A5F);
889
890 pub const Header = struct {
891 magic_number: u32,
892 frame_size: u32,
893 };
894 };
895
896 const InitError = error{
897 /// Frame uses a dictionary.
898 DictionaryIdFlagUnsupported,
899 /// Frame does not have a valid window size.
900 WindowSizeUnknown,
901 /// Window size exceeds `window_size_max` or max `usize` value.
902 WindowOversize,
903 /// Frame header indicates a content size exceeding max `usize` value.
904 ContentOversize,
905 };
906
907 /// Validates `frame_header` and returns the associated `Frame`.
908 pub fn init(
909 frame_header: Frame.Zstandard.Header,
910 window_size_max: usize,
911 verify_checksum: bool,
912 ) InitError!Frame {
913 if (frame_header.descriptor.dictionary_id_flag != 0)
914 return error.DictionaryIdFlagUnsupported;
915
916 const window_size_raw = frame_header.windowSize() orelse return error.WindowSizeUnknown;
917 const window_size = if (window_size_raw > window_size_max)
918 return error.WindowOversize
919 else
920 std.math.cast(usize, window_size_raw) orelse return error.WindowOversize;
921
922 const should_compute_checksum =
923 frame_header.descriptor.content_checksum_flag and verify_checksum;
924
925 const content_size = if (frame_header.content_size) |size|
926 std.math.cast(usize, size) orelse return error.ContentOversize
927 else
928 null;
929
930 return .{
931 .hasher_opt = if (should_compute_checksum) std.hash.XxHash64.init(0) else null,
932 .window_size = window_size,
933 .has_checksum = frame_header.descriptor.content_checksum_flag,
934 .block_size_max = @min(zstd.block_size_max, window_size),
935 .content_size = content_size,
936 };
937 }
938};
939
940pub const LiteralsSection = struct {
941 header: Header,
942 huffman_tree: ?HuffmanTree,
943 streams: Streams,
944
945 pub const Streams = union(enum) {
946 one: []const u8,
947 four: [4][]const u8,
948
949 fn decode(size_format: u2, stream_data: []const u8) !Streams {
950 if (size_format == 0) {
951 return .{ .one = stream_data };
952 }
953
954 if (stream_data.len < 6) return error.MalformedLiteralsSection;
955
956 const stream_1_length: usize = std.mem.readInt(u16, stream_data[0..2], .little);
957 const stream_2_length: usize = std.mem.readInt(u16, stream_data[2..4], .little);
958 const stream_3_length: usize = std.mem.readInt(u16, stream_data[4..6], .little);
959
960 const stream_1_start = 6;
961 const stream_2_start = stream_1_start + stream_1_length;
962 const stream_3_start = stream_2_start + stream_2_length;
963 const stream_4_start = stream_3_start + stream_3_length;
964
965 if (stream_data.len < stream_4_start) return error.MalformedLiteralsSection;
966
967 return .{ .four = .{
968 stream_data[stream_1_start .. stream_1_start + stream_1_length],
969 stream_data[stream_2_start .. stream_2_start + stream_2_length],
970 stream_data[stream_3_start .. stream_3_start + stream_3_length],
971 stream_data[stream_4_start..],
972 } };
973 }
974 };
975
976 pub const Header = struct {
977 block_type: BlockType,
978 size_format: u2,
979 regenerated_size: u20,
980 compressed_size: ?u18,
981
982 /// Decode a literals section header.
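        ///
        /// Header size follows from the cases below: raw/rle blocks use 1 byte for
        /// size formats 0 and 2, 2 bytes for format 1, and 3 bytes for format 3;
        /// compressed/treeless blocks use 3 bytes for formats 0 and 1, 4 bytes for
        /// format 2, and 5 bytes for format 3.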
983 pub fn decode(in: *Reader, remaining: *Limit) !Header {
984 remaining.* = remaining.subtract(1) orelse return error.EndOfStream;
985 const byte0 = try in.takeByte();
986 const block_type: BlockType = @enumFromInt(byte0 & 0b11);
987 const size_format: u2 = @intCast((byte0 & 0b1100) >> 2);
988 var regenerated_size: u20 = undefined;
989 var compressed_size: ?u18 = null;
990 switch (block_type) {
991 .raw, .rle => {
992 switch (size_format) {
993 0, 2 => {
994 regenerated_size = byte0 >> 3;
995 },
996 1 => {
997 remaining.* = remaining.subtract(1) orelse return error.EndOfStream;
998 regenerated_size = (byte0 >> 4) + (@as(u20, try in.takeByte()) << 4);
999 },
1000 3 => {
1001 remaining.* = remaining.subtract(2) orelse return error.EndOfStream;
1002 regenerated_size = (byte0 >> 4) +
1003 (@as(u20, try in.takeByte()) << 4) +
1004 (@as(u20, try in.takeByte()) << 12);
1005 },
1006 }
1007 },
1008 .compressed, .treeless => {
1009 remaining.* = remaining.subtract(2) orelse return error.EndOfStream;
1010 const byte1 = try in.takeByte();
1011 const byte2 = try in.takeByte();
1012 switch (size_format) {
1013 0, 1 => {
1014 regenerated_size = (byte0 >> 4) + ((@as(u20, byte1) & 0b00111111) << 4);
1015 compressed_size = ((byte1 & 0b11000000) >> 6) + (@as(u18, byte2) << 2);
1016 },
1017 2 => {
1018 remaining.* = remaining.subtract(1) orelse return error.EndOfStream;
1019 const byte3 = try in.takeByte();
1020 regenerated_size = (byte0 >> 4) + (@as(u20, byte1) << 4) + ((@as(u20, byte2) & 0b00000011) << 12);
1021 compressed_size = ((byte2 & 0b11111100) >> 2) + (@as(u18, byte3) << 6);
1022 },
1023 3 => {
1024 remaining.* = remaining.subtract(2) orelse return error.EndOfStream;
1025 const byte3 = try in.takeByte();
1026 const byte4 = try in.takeByte();
1027 regenerated_size = (byte0 >> 4) + (@as(u20, byte1) << 4) + ((@as(u20, byte2) & 0b00111111) << 12);
1028 compressed_size = ((byte2 & 0b11000000) >> 6) + (@as(u18, byte3) << 2) + (@as(u18, byte4) << 10);
1029 },
1030 }
1031 },
1032 }
1033 return .{
1034 .block_type = block_type,
1035 .size_format = size_format,
1036 .regenerated_size = regenerated_size,
1037 .compressed_size = compressed_size,
1038 };
1039 }
1040 };
1041
1042 pub const BlockType = enum(u2) {
1043 raw,
1044 rle,
1045 compressed,
1046 treeless,
1047 };
1048
1049 pub const HuffmanTree = struct {
1050 max_bit_count: u4,
1051 symbol_count_minus_one: u8,
1052 nodes: [256]PrefixedSymbol,
1053
1054 pub const PrefixedSymbol = struct {
1055 symbol: u8,
1056 prefix: u16,
1057 weight: u4,
1058 };
1059
1060 pub const Result = union(enum) {
1061 symbol: u8,
1062 index: usize,
1063 };
1064
1065 pub fn query(self: HuffmanTree, index: usize, prefix: u16) error{HuffmanTreeIncomplete}!Result {
1066 var node = self.nodes[index];
1067 const weight = node.weight;
1068 var i: usize = index;
1069 while (node.weight == weight) {
1070 if (node.prefix == prefix) return .{ .symbol = node.symbol };
1071 if (i == 0) return error.HuffmanTreeIncomplete;
1072 i -= 1;
1073 node = self.nodes[i];
1074 }
1075 return .{ .index = i };
1076 }
1077
1078 pub fn weightToBitCount(weight: u4, max_bit_count: u4) u4 {
1079 return if (weight == 0) 0 else ((max_bit_count + 1) - weight);
1080 }
1081
1082 pub const DecodeError = Reader.Error || error{
1083 MalformedHuffmanTree,
1084 MalformedFseTable,
1085 MalformedAccuracyLog,
1086 EndOfStream,
1087 MissingStartBit,
1088 };
1089
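        /// Decode a Huffman tree description. A first byte below 128 is the size in
        /// bytes of an FSE-compressed weight stream; otherwise `header - 127`
        /// weights follow directly, packed two per byte.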
1090 pub fn decode(in: *Reader, remaining: *Limit) HuffmanTree.DecodeError!HuffmanTree {
1091 remaining.* = remaining.subtract(1) orelse return error.EndOfStream;
1092 const header = try in.takeByte();
1093 if (header < 128) {
1094 return decodeFse(in, remaining, header);
1095 } else {
1096 return decodeDirect(in, remaining, header - 127);
1097 }
1098 }
1099
1100 fn decodeDirect(
1101 in: *Reader,
1102 remaining: *Limit,
1103 encoded_symbol_count: usize,
1104 ) HuffmanTree.DecodeError!HuffmanTree {
1105 var weights: [256]u4 = undefined;
1106 const weights_byte_count = (encoded_symbol_count + 1) / 2;
1107 remaining.* = remaining.subtract(weights_byte_count) orelse return error.EndOfStream;
1108 for (0..weights_byte_count) |i| {
1109 const byte = try in.takeByte();
1110 weights[2 * i] = @as(u4, @intCast(byte >> 4));
1111 weights[2 * i + 1] = @as(u4, @intCast(byte & 0xF));
1112 }
1113 const symbol_count = encoded_symbol_count + 1;
1114 return build(&weights, symbol_count);
1115 }
1116
1117 fn decodeFse(
1118 in: *Reader,
1119 remaining: *Limit,
1120 compressed_size: usize,
1121 ) HuffmanTree.DecodeError!HuffmanTree {
1122 var weights: [256]u4 = undefined;
1123 remaining.* = remaining.subtract(compressed_size) orelse return error.EndOfStream;
1124 const compressed_buffer = try in.take(compressed_size);
1125 var bit_reader: BitReader = .{ .bytes = compressed_buffer };
1126 var entries: [1 << 6]Table.Fse = undefined;
1127 const table_size = try Table.decode(&bit_reader, 256, 6, &entries);
1128 const accuracy_log = std.math.log2_int_ceil(usize, table_size);
1129 const remaining_buffer = bit_reader.bytes[bit_reader.index..];
1130 const symbol_count = try assignWeights(remaining_buffer, accuracy_log, &entries, &weights);
1131 return build(&weights, symbol_count);
1132 }
1133
1134 fn assignWeights(
1135 huff_bits_buffer: []const u8,
1136 accuracy_log: u16,
1137 entries: *[1 << 6]Table.Fse,
1138 weights: *[256]u4,
1139 ) !usize {
1140 var huff_bits = try ReverseBitReader.init(huff_bits_buffer);
1141
1142 var i: usize = 0;
1143 var even_state: u32 = try huff_bits.readBitsNoEof(u32, accuracy_log);
1144 var odd_state: u32 = try huff_bits.readBitsNoEof(u32, accuracy_log);
1145
1146 while (i < 254) {
1147 const even_data = entries[even_state];
1148 var read_bits: u16 = 0;
1149 const even_bits = huff_bits.readBits(u32, even_data.bits, &read_bits) catch unreachable;
1150 weights[i] = std.math.cast(u4, even_data.symbol) orelse return error.MalformedHuffmanTree;
1151 i += 1;
1152 if (read_bits < even_data.bits) {
1153 weights[i] = std.math.cast(u4, entries[odd_state].symbol) orelse return error.MalformedHuffmanTree;
1154 i += 1;
1155 break;
1156 }
1157 even_state = even_data.baseline + even_bits;
1158
1159 read_bits = 0;
1160 const odd_data = entries[odd_state];
1161 const odd_bits = huff_bits.readBits(u32, odd_data.bits, &read_bits) catch unreachable;
1162 weights[i] = std.math.cast(u4, odd_data.symbol) orelse return error.MalformedHuffmanTree;
1163 i += 1;
1164 if (read_bits < odd_data.bits) {
1165 if (i == 255) return error.MalformedHuffmanTree;
1166 weights[i] = std.math.cast(u4, entries[even_state].symbol) orelse return error.MalformedHuffmanTree;
1167 i += 1;
1168 break;
1169 }
1170 odd_state = odd_data.baseline + odd_bits;
1171 } else return error.MalformedHuffmanTree;
1172
1173 if (!huff_bits.isEmpty()) {
1174 return error.MalformedHuffmanTree;
1175 }
1176
1177 return i + 1; // stream contains all but the last symbol
1178 }
1179
1180 fn assignSymbols(weight_sorted_prefixed_symbols: []PrefixedSymbol, weights: [256]u4) usize {
1181 for (0..weight_sorted_prefixed_symbols.len) |i| {
1182 weight_sorted_prefixed_symbols[i] = .{
1183 .symbol = @as(u8, @intCast(i)),
1184 .weight = undefined,
1185 .prefix = undefined,
1186 };
1187 }
1188
1189 std.mem.sort(
1190 PrefixedSymbol,
1191 weight_sorted_prefixed_symbols,
1192 weights,
1193 lessThanByWeight,
1194 );
1195
1196 var prefix: u16 = 0;
1197 var prefixed_symbol_count: usize = 0;
1198 var sorted_index: usize = 0;
1199 const symbol_count = weight_sorted_prefixed_symbols.len;
1200 while (sorted_index < symbol_count) {
1201 var symbol = weight_sorted_prefixed_symbols[sorted_index].symbol;
1202 const weight = weights[symbol];
1203 if (weight == 0) {
1204 sorted_index += 1;
1205 continue;
1206 }
1207
1208 while (sorted_index < symbol_count) : ({
1209 sorted_index += 1;
1210 prefixed_symbol_count += 1;
1211 prefix += 1;
1212 }) {
1213 symbol = weight_sorted_prefixed_symbols[sorted_index].symbol;
1214 if (weights[symbol] != weight) {
1215 prefix = ((prefix - 1) >> (weights[symbol] - weight)) + 1;
1216 break;
1217 }
1218 weight_sorted_prefixed_symbols[prefixed_symbol_count].symbol = symbol;
1219 weight_sorted_prefixed_symbols[prefixed_symbol_count].prefix = prefix;
1220 weight_sorted_prefixed_symbols[prefixed_symbol_count].weight = weight;
1221 }
1222 }
1223 return prefixed_symbol_count;
1224 }
1225
1226 fn build(weights: *[256]u4, symbol_count: usize) error{MalformedHuffmanTree}!HuffmanTree {
1227 var weight_power_sum_big: u32 = 0;
1228 for (weights[0 .. symbol_count - 1]) |value| {
1229 weight_power_sum_big += (@as(u16, 1) << value) >> 1;
1230 }
1231 if (weight_power_sum_big >= 1 << 11) return error.MalformedHuffmanTree;
1232 const weight_power_sum = @as(u16, @intCast(weight_power_sum_big));
1233
1234 // advance to next power of two (even if weight_power_sum is a power of 2)
1235 // TODO: is it valid to have weight_power_sum == 0?
1236 const max_number_of_bits = if (weight_power_sum == 0) 1 else std.math.log2_int(u16, weight_power_sum) + 1;
1237 const next_power_of_two = @as(u16, 1) << max_number_of_bits;
1238 weights[symbol_count - 1] = std.math.log2_int(u16, next_power_of_two - weight_power_sum) + 1;
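            // For example, if weight_power_sum is 6, next_power_of_two is 8 and the
            // last weight becomes log2(8 - 6) + 1 = 2, contributing (1 << 2) >> 1 = 2
            // and bringing the total up to a power of two.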
1239
1240 var weight_sorted_prefixed_symbols: [256]PrefixedSymbol = undefined;
1241 const prefixed_symbol_count = assignSymbols(weight_sorted_prefixed_symbols[0..symbol_count], weights.*);
1242 const tree: HuffmanTree = .{
1243 .max_bit_count = max_number_of_bits,
1244 .symbol_count_minus_one = @as(u8, @intCast(prefixed_symbol_count - 1)),
1245 .nodes = weight_sorted_prefixed_symbols,
1246 };
1247 return tree;
1248 }
1249
1250 fn lessThanByWeight(
1251 weights: [256]u4,
1252 lhs: PrefixedSymbol,
1253 rhs: PrefixedSymbol,
1254 ) bool {
            // NOTE: this function relies on a stable sorting algorithm; otherwise a
            // tie-breaking case such as
            // `if (weights[lhs.symbol] == weights[rhs.symbol]) return lhs.symbol < rhs.symbol;`
            // would need to be added.
1258 return weights[lhs.symbol] < weights[rhs.symbol];
1259 }
1260 };
1261
1262 pub const StreamCount = enum { one, four };
1263 pub fn streamCount(size_format: u2, block_type: BlockType) StreamCount {
1264 return switch (block_type) {
1265 .raw, .rle => .one,
1266 .compressed, .treeless => if (size_format == 0) .one else .four,
1267 };
1268 }
1269
1270 pub const DecodeError = error{
1271 /// Invalid header.
1272 MalformedLiteralsHeader,
        /// Malformed literals section data.
1274 MalformedLiteralsSection,
1275 /// Compressed literals have invalid accuracy.
1276 MalformedAccuracyLog,
1277 /// Compressed literals have invalid FSE table.
1278 MalformedFseTable,
        /// Failed decoding a Huffman tree.
1280 MalformedHuffmanTree,
1281 /// Not enough bytes to complete the section.
1282 EndOfStream,
1283 ReadFailed,
1284 MissingStartBit,
1285 };
1286
1287 pub fn decode(in: *Reader, remaining: *Limit, buffer: []u8) DecodeError!LiteralsSection {
1288 const header = try Header.decode(in, remaining);
1289 switch (header.block_type) {
1290 .raw => {
1291 if (buffer.len < header.regenerated_size) return error.MalformedLiteralsSection;
1292 remaining.* = remaining.subtract(header.regenerated_size) orelse return error.EndOfStream;
1293 try in.readSliceAll(buffer[0..header.regenerated_size]);
1294 return .{
1295 .header = header,
1296 .huffman_tree = null,
1297 .streams = .{ .one = buffer },
1298 };
1299 },
1300 .rle => {
1301 remaining.* = remaining.subtract(1) orelse return error.EndOfStream;
1302 buffer[0] = try in.takeByte();
1303 return .{
1304 .header = header,
1305 .huffman_tree = null,
1306 .streams = .{ .one = buffer[0..1] },
1307 };
1308 },
1309 .compressed, .treeless => {
1310 const before_remaining = remaining.*;
1311 const huffman_tree = if (header.block_type == .compressed)
1312 try HuffmanTree.decode(in, remaining)
1313 else
1314 null;
1315 const huffman_tree_size = @intFromEnum(before_remaining) - @intFromEnum(remaining.*);
1316 const total_streams_size = std.math.sub(usize, header.compressed_size.?, huffman_tree_size) catch
1317 return error.MalformedLiteralsSection;
1318 if (total_streams_size > buffer.len) return error.MalformedLiteralsSection;
1319 remaining.* = remaining.subtract(total_streams_size) orelse return error.EndOfStream;
1320 try in.readSliceAll(buffer[0..total_streams_size]);
1321 const stream_data = buffer[0..total_streams_size];
1322 const streams = try Streams.decode(header.size_format, stream_data);
1323 return .{
1324 .header = header,
1325 .huffman_tree = huffman_tree,
1326 .streams = streams,
1327 };
1328 },
1329 }
1330 }
1331};
1332
1333pub const SequencesSection = struct {
1334 header: Header,
1335 literals_length_table: Table,
1336 offset_table: Table,
1337 match_length_table: Table,
1338
1339 pub const Header = struct {
1340 sequence_count: u24,
1341 match_lengths: Mode,
1342 offsets: Mode,
1343 literal_lengths: Mode,
1344
1345 pub const Mode = enum(u2) {
1346 predefined,
1347 rle,
1348 fse,
1349 repeat,
1350 };
1351
1352 pub const DecodeError = error{
1353 ReservedBitSet,
1354 EndOfStream,
1355 ReadFailed,
1356 };
1357
1358 pub fn decode(in: *Reader, remaining: *Limit) DecodeError!Header {
1359 var sequence_count: u24 = undefined;
1360
1361 remaining.* = remaining.subtract(1) orelse return error.EndOfStream;
1362 const byte0 = try in.takeByte();
1363 if (byte0 == 0) {
1364 return .{
1365 .sequence_count = 0,
1366 .offsets = undefined,
1367 .match_lengths = undefined,
1368 .literal_lengths = undefined,
1369 };
1370 } else if (byte0 < 128) {
1371 remaining.* = remaining.subtract(1) orelse return error.EndOfStream;
1372 sequence_count = byte0;
1373 } else if (byte0 < 255) {
1374 remaining.* = remaining.subtract(2) orelse return error.EndOfStream;
1375 sequence_count = (@as(u24, (byte0 - 128)) << 8) + try in.takeByte();
1376 } else {
1377 remaining.* = remaining.subtract(3) orelse return error.EndOfStream;
1378 sequence_count = (try in.takeByte()) + (@as(u24, try in.takeByte()) << 8) + 0x7F00;
1379 }
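            // Example of the two-byte form above: byte0 = 0x85 followed by 0x10
            // encodes ((0x85 - 128) << 8) + 0x10 = 1296 sequences.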
1380
1381 const compression_modes = try in.takeByte();
1382
1383 const matches_mode: Header.Mode = @enumFromInt((compression_modes & 0b00001100) >> 2);
1384 const offsets_mode: Header.Mode = @enumFromInt((compression_modes & 0b00110000) >> 4);
1385 const literal_mode: Header.Mode = @enumFromInt((compression_modes & 0b11000000) >> 6);
1386 if (compression_modes & 0b11 != 0) return error.ReservedBitSet;
1387
1388 return .{
1389 .sequence_count = sequence_count,
1390 .offsets = offsets_mode,
1391 .match_lengths = matches_mode,
1392 .literal_lengths = literal_mode,
1393 };
1394 }
1395 };
1396};
1397
1398pub const Table = union(enum) {
1399 fse: []const Fse,
1400 rle: u8,
1401
1402 pub const Fse = struct {
1403 symbol: u8,
1404 baseline: u16,
1405 bits: u8,
1406 };
1407
1408 pub fn decode(
1409 bit_reader: *BitReader,
1410 expected_symbol_count: usize,
1411 max_accuracy_log: u4,
1412 entries: []Table.Fse,
1413 ) !usize {
1414 const accuracy_log_biased = try bit_reader.readBitsNoEof(u4, 4);
1415 if (accuracy_log_biased > max_accuracy_log -| 5) return error.MalformedAccuracyLog;
1416 const accuracy_log = accuracy_log_biased + 5;
1417
1418 var values: [256]u16 = undefined;
1419 var value_count: usize = 0;
1420
1421 const total_probability = @as(u16, 1) << accuracy_log;
1422 var accumulated_probability: u16 = 0;
1423
1424 while (accumulated_probability < total_probability) {
1425 // WARNING: The RFC is poorly worded, and would suggest std.math.log2_int_ceil is correct here,
1426 // but power of two (remaining probabilities + 1) need max bits set to 1 more.
1427 const max_bits = std.math.log2_int(u16, total_probability - accumulated_probability + 1) + 1;
1428 const small = try bit_reader.readBitsNoEof(u16, max_bits - 1);
1429
1430 const cutoff = (@as(u16, 1) << max_bits) - 1 - (total_probability - accumulated_probability + 1);
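            // Worked example: with an accuracy log of 6 and nothing accumulated yet
            // there are 66 possible values (0..=65), so max_bits is 7 and cutoff is
            // 62; codes below 62 use only max_bits - 1 = 6 bits, the rest use 7.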
1431
1432 const value = if (small < cutoff)
1433 small
1434 else value: {
1435 const value_read = small + (try bit_reader.readBitsNoEof(u16, 1) << (max_bits - 1));
1436 break :value if (value_read < @as(u16, 1) << (max_bits - 1))
1437 value_read
1438 else
1439 value_read - cutoff;
1440 };
1441
1442 accumulated_probability += if (value != 0) value - 1 else 1;
1443
1444 values[value_count] = value;
1445 value_count += 1;
1446
1447 if (value == 1) {
1448 while (true) {
1449 const repeat_flag = try bit_reader.readBitsNoEof(u2, 2);
1450 if (repeat_flag + value_count > 256) return error.MalformedFseTable;
1451 for (0..repeat_flag) |_| {
1452 values[value_count] = 1;
1453 value_count += 1;
1454 }
1455 if (repeat_flag < 3) break;
1456 }
1457 }
1458 if (value_count == 256) break;
1459 }
1460 bit_reader.alignToByte();
1461
1462 if (value_count < 2) return error.MalformedFseTable;
1463 if (accumulated_probability != total_probability) return error.MalformedFseTable;
1464 if (value_count > expected_symbol_count) return error.MalformedFseTable;
1465
1466 const table_size = total_probability;
1467
1468 try build(values[0..value_count], entries[0..table_size]);
1469 return table_size;
1470 }
1471
1472 pub fn build(values: []const u16, entries: []Table.Fse) !void {
1473 const total_probability = @as(u16, @intCast(entries.len));
1474 const accuracy_log = std.math.log2_int(u16, total_probability);
1475 assert(total_probability <= 1 << 9);
1476
1477 var less_than_one_count: usize = 0;
1478 for (values, 0..) |value, i| {
1479 if (value == 0) {
1480 entries[entries.len - 1 - less_than_one_count] = Table.Fse{
1481 .symbol = @as(u8, @intCast(i)),
1482 .baseline = 0,
1483 .bits = accuracy_log,
1484 };
1485 less_than_one_count += 1;
1486 }
1487 }
1488
1489 var position: usize = 0;
1490 var temp_states: [1 << 9]u16 = undefined;
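        // Spread each remaining symbol's states across the table using the step
        // (tableSize >> 1) + (tableSize >> 3) + 3, wrapping modulo the table size and
        // skipping the positions reserved at the end for "less than one" probability
        // symbols.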
1491 for (values, 0..) |value, symbol| {
1492 if (value == 0 or value == 1) continue;
1493 const probability = value - 1;
1494
1495 const state_share_dividend = std.math.ceilPowerOfTwo(u16, probability) catch
1496 return error.MalformedFseTable;
1497 const share_size = @divExact(total_probability, state_share_dividend);
1498 const double_state_count = state_share_dividend - probability;
1499 const single_state_count = probability - double_state_count;
1500 const share_size_log = std.math.log2_int(u16, share_size);
1501
1502 for (0..probability) |i| {
1503 temp_states[i] = @as(u16, @intCast(position));
1504 position += (entries.len >> 1) + (entries.len >> 3) + 3;
1505 position &= entries.len - 1;
1506 while (position >= entries.len - less_than_one_count) {
1507 position += (entries.len >> 1) + (entries.len >> 3) + 3;
1508 position &= entries.len - 1;
1509 }
1510 }
1511 std.mem.sort(u16, temp_states[0..probability], {}, std.sort.asc(u16));
1512 for (0..probability) |i| {
1513 entries[temp_states[i]] = if (i < double_state_count) Table.Fse{
1514 .symbol = @as(u8, @intCast(symbol)),
1515 .bits = share_size_log + 1,
1516 .baseline = single_state_count * share_size + @as(u16, @intCast(i)) * 2 * share_size,
1517 } else Table.Fse{
1518 .symbol = @as(u8, @intCast(symbol)),
1519 .bits = share_size_log,
1520 .baseline = (@as(u16, @intCast(i)) - double_state_count) * share_size,
1521 };
1522 }
1523 }
1524 }
1525
1526 test build {
1527 const literals_length_default_values = [36]u16{
1528 5, 4, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2,
1529 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 3, 2, 2, 2, 2, 2,
1530 0, 0, 0, 0,
1531 };
1532
1533 const match_lengths_default_values = [53]u16{
1534 2, 5, 4, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2,
1535 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
1536 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0,
1537 0, 0, 0, 0, 0,
1538 };
1539
1540 const offset_codes_default_values = [29]u16{
1541 2, 2, 2, 2, 2, 2, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2,
1542 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0,
1543 };
1544
1545 var entries: [64]Table.Fse = undefined;
1546 try build(&literals_length_default_values, &entries);
1547 try std.testing.expectEqualSlices(Table.Fse, Table.predefined_literal.fse, &entries);
1548
1549 try build(&match_lengths_default_values, &entries);
1550 try std.testing.expectEqualSlices(Table.Fse, Table.predefined_match.fse, &entries);
1551
1552 try build(&offset_codes_default_values, entries[0..32]);
1553 try std.testing.expectEqualSlices(Table.Fse, Table.predefined_offset.fse, entries[0..32]);
1554 }
1555
1556 pub const predefined_literal: Table = .{
1557 .fse = &[64]Table.Fse{
1558 .{ .symbol = 0, .bits = 4, .baseline = 0 },
1559 .{ .symbol = 0, .bits = 4, .baseline = 16 },
1560 .{ .symbol = 1, .bits = 5, .baseline = 32 },
1561 .{ .symbol = 3, .bits = 5, .baseline = 0 },
1562 .{ .symbol = 4, .bits = 5, .baseline = 0 },
1563 .{ .symbol = 6, .bits = 5, .baseline = 0 },
1564 .{ .symbol = 7, .bits = 5, .baseline = 0 },
1565 .{ .symbol = 9, .bits = 5, .baseline = 0 },
1566 .{ .symbol = 10, .bits = 5, .baseline = 0 },
1567 .{ .symbol = 12, .bits = 5, .baseline = 0 },
1568 .{ .symbol = 14, .bits = 6, .baseline = 0 },
1569 .{ .symbol = 16, .bits = 5, .baseline = 0 },
1570 .{ .symbol = 18, .bits = 5, .baseline = 0 },
1571 .{ .symbol = 19, .bits = 5, .baseline = 0 },
1572 .{ .symbol = 21, .bits = 5, .baseline = 0 },
1573 .{ .symbol = 22, .bits = 5, .baseline = 0 },
1574 .{ .symbol = 24, .bits = 5, .baseline = 0 },
1575 .{ .symbol = 25, .bits = 5, .baseline = 32 },
1576 .{ .symbol = 26, .bits = 5, .baseline = 0 },
1577 .{ .symbol = 27, .bits = 6, .baseline = 0 },
1578 .{ .symbol = 29, .bits = 6, .baseline = 0 },
1579 .{ .symbol = 31, .bits = 6, .baseline = 0 },
1580 .{ .symbol = 0, .bits = 4, .baseline = 32 },
1581 .{ .symbol = 1, .bits = 4, .baseline = 0 },
1582 .{ .symbol = 2, .bits = 5, .baseline = 0 },
1583 .{ .symbol = 4, .bits = 5, .baseline = 32 },
1584 .{ .symbol = 5, .bits = 5, .baseline = 0 },
1585 .{ .symbol = 7, .bits = 5, .baseline = 32 },
1586 .{ .symbol = 8, .bits = 5, .baseline = 0 },
1587 .{ .symbol = 10, .bits = 5, .baseline = 32 },
1588 .{ .symbol = 11, .bits = 5, .baseline = 0 },
1589 .{ .symbol = 13, .bits = 6, .baseline = 0 },
            .{ .symbol = 16, .bits = 5, .baseline = 32 },
            .{ .symbol = 17, .bits = 5, .baseline = 0 },
            .{ .symbol = 19, .bits = 5, .baseline = 32 },
            .{ .symbol = 20, .bits = 5, .baseline = 0 },
            .{ .symbol = 22, .bits = 5, .baseline = 32 },
            .{ .symbol = 23, .bits = 5, .baseline = 0 },
            .{ .symbol = 25, .bits = 4, .baseline = 0 },
            .{ .symbol = 25, .bits = 4, .baseline = 16 },
            .{ .symbol = 26, .bits = 5, .baseline = 32 },
            .{ .symbol = 28, .bits = 6, .baseline = 0 },
            .{ .symbol = 30, .bits = 6, .baseline = 0 },
            .{ .symbol = 0, .bits = 4, .baseline = 48 },
            .{ .symbol = 1, .bits = 4, .baseline = 16 },
            .{ .symbol = 2, .bits = 5, .baseline = 32 },
            .{ .symbol = 3, .bits = 5, .baseline = 32 },
            .{ .symbol = 5, .bits = 5, .baseline = 32 },
            .{ .symbol = 6, .bits = 5, .baseline = 32 },
            .{ .symbol = 8, .bits = 5, .baseline = 32 },
            .{ .symbol = 9, .bits = 5, .baseline = 32 },
            .{ .symbol = 11, .bits = 5, .baseline = 32 },
            .{ .symbol = 12, .bits = 5, .baseline = 32 },
            .{ .symbol = 15, .bits = 6, .baseline = 0 },
            .{ .symbol = 17, .bits = 5, .baseline = 32 },
            .{ .symbol = 18, .bits = 5, .baseline = 32 },
            .{ .symbol = 20, .bits = 5, .baseline = 32 },
            .{ .symbol = 21, .bits = 5, .baseline = 32 },
            .{ .symbol = 23, .bits = 5, .baseline = 32 },
            .{ .symbol = 24, .bits = 5, .baseline = 32 },
            .{ .symbol = 35, .bits = 6, .baseline = 0 },
            .{ .symbol = 34, .bits = 6, .baseline = 0 },
            .{ .symbol = 33, .bits = 6, .baseline = 0 },
            .{ .symbol = 32, .bits = 6, .baseline = 0 },
        },
    };

    pub const predefined_match: Table = .{
        .fse = &[64]Table.Fse{
            .{ .symbol = 0, .bits = 6, .baseline = 0 },
            .{ .symbol = 1, .bits = 4, .baseline = 0 },
            .{ .symbol = 2, .bits = 5, .baseline = 32 },
            .{ .symbol = 3, .bits = 5, .baseline = 0 },
            .{ .symbol = 5, .bits = 5, .baseline = 0 },
            .{ .symbol = 6, .bits = 5, .baseline = 0 },
            .{ .symbol = 8, .bits = 5, .baseline = 0 },
            .{ .symbol = 10, .bits = 6, .baseline = 0 },
            .{ .symbol = 13, .bits = 6, .baseline = 0 },
            .{ .symbol = 16, .bits = 6, .baseline = 0 },
            .{ .symbol = 19, .bits = 6, .baseline = 0 },
            .{ .symbol = 22, .bits = 6, .baseline = 0 },
            .{ .symbol = 25, .bits = 6, .baseline = 0 },
            .{ .symbol = 28, .bits = 6, .baseline = 0 },
            .{ .symbol = 31, .bits = 6, .baseline = 0 },
            .{ .symbol = 33, .bits = 6, .baseline = 0 },
            .{ .symbol = 35, .bits = 6, .baseline = 0 },
            .{ .symbol = 37, .bits = 6, .baseline = 0 },
            .{ .symbol = 39, .bits = 6, .baseline = 0 },
            .{ .symbol = 41, .bits = 6, .baseline = 0 },
            .{ .symbol = 43, .bits = 6, .baseline = 0 },
            .{ .symbol = 45, .bits = 6, .baseline = 0 },
            .{ .symbol = 1, .bits = 4, .baseline = 16 },
            .{ .symbol = 2, .bits = 4, .baseline = 0 },
            .{ .symbol = 3, .bits = 5, .baseline = 32 },
            .{ .symbol = 4, .bits = 5, .baseline = 0 },
            .{ .symbol = 6, .bits = 5, .baseline = 32 },
            .{ .symbol = 7, .bits = 5, .baseline = 0 },
            .{ .symbol = 9, .bits = 6, .baseline = 0 },
            .{ .symbol = 12, .bits = 6, .baseline = 0 },
            .{ .symbol = 15, .bits = 6, .baseline = 0 },
            .{ .symbol = 18, .bits = 6, .baseline = 0 },
            .{ .symbol = 21, .bits = 6, .baseline = 0 },
            .{ .symbol = 24, .bits = 6, .baseline = 0 },
            .{ .symbol = 27, .bits = 6, .baseline = 0 },
            .{ .symbol = 30, .bits = 6, .baseline = 0 },
            .{ .symbol = 32, .bits = 6, .baseline = 0 },
            .{ .symbol = 34, .bits = 6, .baseline = 0 },
            .{ .symbol = 36, .bits = 6, .baseline = 0 },
            .{ .symbol = 38, .bits = 6, .baseline = 0 },
            .{ .symbol = 40, .bits = 6, .baseline = 0 },
            .{ .symbol = 42, .bits = 6, .baseline = 0 },
            .{ .symbol = 44, .bits = 6, .baseline = 0 },
            .{ .symbol = 1, .bits = 4, .baseline = 32 },
            .{ .symbol = 1, .bits = 4, .baseline = 48 },
            .{ .symbol = 2, .bits = 4, .baseline = 16 },
            .{ .symbol = 4, .bits = 5, .baseline = 32 },
            .{ .symbol = 5, .bits = 5, .baseline = 32 },
            .{ .symbol = 7, .bits = 5, .baseline = 32 },
            .{ .symbol = 8, .bits = 5, .baseline = 32 },
            .{ .symbol = 11, .bits = 6, .baseline = 0 },
            .{ .symbol = 14, .bits = 6, .baseline = 0 },
            .{ .symbol = 17, .bits = 6, .baseline = 0 },
            .{ .symbol = 20, .bits = 6, .baseline = 0 },
            .{ .symbol = 23, .bits = 6, .baseline = 0 },
            .{ .symbol = 26, .bits = 6, .baseline = 0 },
            .{ .symbol = 29, .bits = 6, .baseline = 0 },
            .{ .symbol = 52, .bits = 6, .baseline = 0 },
            .{ .symbol = 51, .bits = 6, .baseline = 0 },
            .{ .symbol = 50, .bits = 6, .baseline = 0 },
            .{ .symbol = 49, .bits = 6, .baseline = 0 },
            .{ .symbol = 48, .bits = 6, .baseline = 0 },
            .{ .symbol = 47, .bits = 6, .baseline = 0 },
            .{ .symbol = 46, .bits = 6, .baseline = 0 },
        },
    };

    pub const predefined_offset: Table = .{
        .fse = &[32]Table.Fse{
            .{ .symbol = 0, .bits = 5, .baseline = 0 },
            .{ .symbol = 6, .bits = 4, .baseline = 0 },
            .{ .symbol = 9, .bits = 5, .baseline = 0 },
            .{ .symbol = 15, .bits = 5, .baseline = 0 },
            .{ .symbol = 21, .bits = 5, .baseline = 0 },
            .{ .symbol = 3, .bits = 5, .baseline = 0 },
            .{ .symbol = 7, .bits = 4, .baseline = 0 },
            .{ .symbol = 12, .bits = 5, .baseline = 0 },
            .{ .symbol = 18, .bits = 5, .baseline = 0 },
            .{ .symbol = 23, .bits = 5, .baseline = 0 },
            .{ .symbol = 5, .bits = 5, .baseline = 0 },
            .{ .symbol = 8, .bits = 4, .baseline = 0 },
            .{ .symbol = 14, .bits = 5, .baseline = 0 },
            .{ .symbol = 20, .bits = 5, .baseline = 0 },
            .{ .symbol = 2, .bits = 5, .baseline = 0 },
            .{ .symbol = 7, .bits = 4, .baseline = 16 },
            .{ .symbol = 11, .bits = 5, .baseline = 0 },
            .{ .symbol = 17, .bits = 5, .baseline = 0 },
            .{ .symbol = 22, .bits = 5, .baseline = 0 },
            .{ .symbol = 4, .bits = 5, .baseline = 0 },
            .{ .symbol = 8, .bits = 4, .baseline = 16 },
            .{ .symbol = 13, .bits = 5, .baseline = 0 },
            .{ .symbol = 19, .bits = 5, .baseline = 0 },
            .{ .symbol = 1, .bits = 5, .baseline = 0 },
            .{ .symbol = 6, .bits = 4, .baseline = 16 },
            .{ .symbol = 10, .bits = 5, .baseline = 0 },
            .{ .symbol = 16, .bits = 5, .baseline = 0 },
            .{ .symbol = 28, .bits = 5, .baseline = 0 },
            .{ .symbol = 27, .bits = 5, .baseline = 0 },
            .{ .symbol = 26, .bits = 5, .baseline = 0 },
            .{ .symbol = 25, .bits = 5, .baseline = 0 },
            .{ .symbol = 24, .bits = 5, .baseline = 0 },
        },
    };
};

/// `low_bit_mask[n]` keeps the low `n` bits of a byte.
const low_bit_mask = [9]u8{
    0b00000000,
    0b00000001,
    0b00000011,
    0b00000111,
    0b00001111,
    0b00011111,
    0b00111111,
    0b01111111,
    0b11111111,
};

fn Bits(comptime T: type) type {
    return struct { T, u16 };
}

/// For reading the reversed bit streams used to encode FSE compressed data.
const ReverseBitReader = struct {
    bytes: []const u8,
    remaining: usize,
    bits: u8,
    count: u4,

    fn init(bytes: []const u8) error{MissingStartBit}!ReverseBitReader {
        var result: ReverseBitReader = .{
            .bytes = bytes,
            .remaining = bytes.len,
            .bits = 0,
            .count = 0,
        };
        if (bytes.len == 0) return result;
        for (0..8) |_| if (0 != (result.readBitsNoEof(u1, 1) catch unreachable)) return result;
        return error.MissingStartBit;
    }

    fn initBits(comptime T: type, out: anytype, num: u16) Bits(T) {
        const UT = std.meta.Int(.unsigned, @bitSizeOf(T));
        return .{
            @bitCast(@as(UT, @intCast(out))),
            num,
        };
    }

    fn readBitsNoEof(self: *ReverseBitReader, comptime T: type, num: u16) error{EndOfStream}!T {
        const b, const c = try self.readBitsTuple(T, num);
        if (c < num) return error.EndOfStream;
        return b;
    }

    fn readBits(self: *ReverseBitReader, comptime T: type, num: u16, out_bits: *u16) !T {
        const b, const c = try self.readBitsTuple(T, num);
        out_bits.* = c;
        return b;
    }

    fn readBitsTuple(self: *ReverseBitReader, comptime T: type, num: u16) !Bits(T) {
        const UT = std.meta.Int(.unsigned, @bitSizeOf(T));
        const U = if (@bitSizeOf(T) < 8) u8 else UT;

        if (num <= self.count) return initBits(T, self.removeBits(@intCast(num)), num);

        var out_count: u16 = self.count;
        var out: U = self.removeBits(self.count);

        const full_bytes_left = (num - out_count) / 8;

        for (0..full_bytes_left) |_| {
            const byte = takeByte(self) catch |err| switch (err) {
                error.EndOfStream => return initBits(T, out, out_count),
            };
            if (U == u8) out = 0 else out <<= 8;
            out |= byte;
            out_count += 8;
        }

        const bits_left = num - out_count;
        const keep = 8 - bits_left;

        if (bits_left == 0) return initBits(T, out, out_count);

        const final_byte = takeByte(self) catch |err| switch (err) {
            error.EndOfStream => return initBits(T, out, out_count),
        };

        out <<= @intCast(bits_left);
        out |= final_byte >> @intCast(keep);
        self.bits = final_byte & low_bit_mask[keep];

        self.count = @intCast(keep);
        return initBits(T, out, num);
    }

    fn takeByte(rbr: *ReverseBitReader) error{EndOfStream}!u8 {
        if (rbr.remaining == 0) return error.EndOfStream;
        rbr.remaining -= 1;
        return rbr.bytes[rbr.remaining];
    }

    fn isEmpty(self: *const ReverseBitReader) bool {
        return self.remaining == 0 and self.count == 0;
    }

    fn removeBits(self: *ReverseBitReader, num: u4) u8 {
        if (num == 8) {
            self.count = 0;
            return self.bits;
        }

        const keep = self.count - num;
        const bits = self.bits >> @intCast(keep);
        self.bits &= low_bit_mask[keep];

        self.count = keep;
        return bits;
    }
};
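
// A small sanity check; the bit pattern and expected values below are
// illustrative assumptions rather than vectors from the zstd test suite.
// After `init` skips the padding up to and including the start bit, bits are
// consumed from the last byte towards the first, most significant bit first.
test "ReverseBitReader reads the last byte first, most significant bit first" {
    // 0b1_010_0110: the leading 1 is the start-bit marker, leaving seven
    // payload bits.
    var rbr = try ReverseBitReader.init(&.{0b1010_0110});
    try std.testing.expectEqual(@as(u3, 0b010), try rbr.readBitsNoEof(u3, 3));
    try std.testing.expectEqual(@as(u4, 0b0110), try rbr.readBitsNoEof(u4, 4));
    try std.testing.expect(rbr.isEmpty());
}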

/// For reading bit streams in the forward direction, least significant bit
/// first.
const BitReader = struct {
    bytes: []const u8,
    index: usize = 0,
    bits: u8 = 0,
    count: u4 = 0,

    fn initBits(comptime T: type, out: anytype, num: u16) Bits(T) {
        const UT = std.meta.Int(.unsigned, @bitSizeOf(T));
        return .{
            @bitCast(@as(UT, @intCast(out))),
            num,
        };
    }

    fn readBitsNoEof(self: *@This(), comptime T: type, num: u16) !T {
        const b, const c = try self.readBitsTuple(T, num);
        if (c < num) return error.EndOfStream;
        return b;
    }

    fn readBits(self: *@This(), comptime T: type, num: u16, out_bits: *u16) !T {
        const b, const c = try self.readBitsTuple(T, num);
        out_bits.* = c;
        return b;
    }

    fn readBitsTuple(self: *@This(), comptime T: type, num: u16) !Bits(T) {
        const UT = std.meta.Int(.unsigned, @bitSizeOf(T));
        const U = if (@bitSizeOf(T) < 8) u8 else UT;

        if (num <= self.count) return initBits(T, self.removeBits(@intCast(num)), num);

        var out_count: u16 = self.count;
        var out: U = self.removeBits(self.count);

        const full_bytes_left = (num - out_count) / 8;

        for (0..full_bytes_left) |_| {
            const byte = takeByte(self) catch |err| switch (err) {
                error.EndOfStream => return initBits(T, out, out_count),
            };

            const pos = @as(U, byte) << @intCast(out_count);
            out |= pos;
            out_count += 8;
        }

        const bits_left = num - out_count;
        const keep = 8 - bits_left;

        if (bits_left == 0) return initBits(T, out, out_count);

        const final_byte = takeByte(self) catch |err| switch (err) {
            error.EndOfStream => return initBits(T, out, out_count),
        };

        const pos = @as(U, final_byte & low_bit_mask[bits_left]) << @intCast(out_count);
        out |= pos;
        self.bits = final_byte >> @intCast(bits_left);

        self.count = @intCast(keep);
        return initBits(T, out, num);
    }

    fn takeByte(br: *BitReader) error{EndOfStream}!u8 {
        if (br.bytes.len - br.index == 0) return error.EndOfStream;
        const result = br.bytes[br.index];
        br.index += 1;
        return result;
    }

    fn removeBits(self: *@This(), num: u4) u8 {
        if (num == 8) {
            self.count = 0;
            return self.bits;
        }

        const keep = self.count - num;
        const bits = self.bits & low_bit_mask[num];
        self.bits >>= @intCast(num);
        self.count = keep;
        return bits;
    }

    fn alignToByte(self: *@This()) void {
        self.bits = 0;
        self.count = 0;
    }
};
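
// A small sanity check; the byte and expected values are illustrative
// assumptions rather than vectors from the zstd test suite. Unlike
// `ReverseBitReader`, this reader walks the bytes front to back and yields
// the least significant bits of each byte first.
test "BitReader reads bytes in order, least significant bit first" {
    var br: BitReader = .{ .bytes = &.{0b1010_0110} };
    try std.testing.expectEqual(@as(u3, 0b110), try br.readBitsNoEof(u3, 3));
    try std.testing.expectEqual(@as(u5, 0b10100), try br.readBitsNoEof(u5, 5));
    try std.testing.expectError(error.EndOfStream, br.readBitsNoEof(u1, 1));
}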

test {
    _ = Table;
}