master
  1const std = @import("../std.zig");
  2const Allocator = std.mem.Allocator;
  3const ArrayList = std.ArrayList;
  4const lzma = std.compress.lzma;
  5const Writer = std.Io.Writer;
  6const Reader = std.Io.Reader;
  7
  8/// An accumulating buffer for LZ sequences
  9pub const AccumBuffer = struct {
 10    /// Buffer
 11    buf: ArrayList(u8),
 12    /// Buffer memory limit
 13    memlimit: usize,
 14    /// Total number of bytes sent through the buffer
 15    len: usize,
 16
 17    pub fn init(memlimit: usize) AccumBuffer {
 18        return .{
 19            .buf = .{},
 20            .memlimit = memlimit,
 21            .len = 0,
 22        };
 23    }
 24
 25    pub fn appendByte(self: *AccumBuffer, allocator: Allocator, byte: u8) !void {
 26        try self.buf.append(allocator, byte);
 27        self.len += 1;
 28    }
 29
 30    /// Reset the internal dictionary
 31    pub fn reset(self: *AccumBuffer, writer: *Writer) !void {
 32        try writer.writeAll(self.buf.items);
 33        self.buf.clearRetainingCapacity();
 34        self.len = 0;
 35    }
 36
 37    /// Retrieve the last byte or return a default
 38    pub fn lastOr(self: AccumBuffer, lit: u8) u8 {
 39        const buf_len = self.buf.items.len;
 40        return if (buf_len == 0)
 41            lit
 42        else
 43            self.buf.items[buf_len - 1];
 44    }
 45
 46    /// Retrieve the n-th last byte
 47    pub fn lastN(self: AccumBuffer, dist: usize) !u8 {
 48        const buf_len = self.buf.items.len;
 49        if (dist > buf_len) {
 50            return error.CorruptInput;
 51        }
 52
 53        return self.buf.items[buf_len - dist];
 54    }
 55
 56    /// Append a literal
 57    pub fn appendLiteral(
 58        self: *AccumBuffer,
 59        allocator: Allocator,
 60        lit: u8,
 61        writer: *Writer,
 62    ) !void {
 63        _ = writer;
 64        if (self.len >= self.memlimit) {
 65            return error.CorruptInput;
 66        }
 67        try self.buf.append(allocator, lit);
 68        self.len += 1;
 69    }
 70
 71    /// Fetch an LZ sequence (length, distance) from inside the buffer
 72    pub fn appendLz(
 73        self: *AccumBuffer,
 74        allocator: Allocator,
 75        len: usize,
 76        dist: usize,
 77        writer: *Writer,
 78    ) !void {
 79        _ = writer;
 80
 81        const buf_len = self.buf.items.len;
 82        if (dist > buf_len) return error.CorruptInput;
 83
 84        try self.buf.ensureUnusedCapacity(allocator, len);
 85        const buffer = self.buf.allocatedSlice();
 86        const src = buffer[buf_len - dist ..][0..len];
 87        const dst = buffer[buf_len..][0..len];
 88
 89        // This is not a @memmove; it intentionally repeats patterns caused by
 90        // iterating one byte at a time.
 91        for (dst, src) |*d, s| d.* = s;
 92
 93        self.buf.items.len = buf_len + len;
 94        self.len += len;
 95    }
 96
 97    pub fn finish(self: *AccumBuffer, writer: *Writer) !void {
 98        try writer.writeAll(self.buf.items);
 99        self.buf.clearRetainingCapacity();
100    }
101
102    pub fn deinit(self: *AccumBuffer, allocator: Allocator) void {
103        self.buf.deinit(allocator);
104        self.* = undefined;
105    }
106};
107
/// LZMA2 stream decoder layered on top of the plain LZMA decoder.
pub const Decode = struct {
    lzma_decode: lzma.Decode,

    /// Create a decoder whose initial LZMA properties are lc = 0, lp = 0,
    /// pb = 0. A chunk that carries a properties byte replaces them.
    pub fn init(gpa: Allocator) !Decode {
        const inner = try lzma.Decode.init(gpa, .{ .lc = 0, .lp = 0, .pb = 0 });
        return .{ .lzma_decode = inner };
    }

    pub fn deinit(self: *Decode, gpa: Allocator) void {
        self.lzma_decode.deinit(gpa);
        self.* = undefined;
    }

    /// Decode chunks from `reader` until a 0x00 control byte. All output
    /// ends up in `allocating.writer`.
    ///
    /// Returns how many compressed bytes were consumed, counting every control
    /// byte including the terminating one.
    pub fn decompress(d: *Decode, reader: *Reader, allocating: *Writer.Allocating) !u64 {
        const gpa = allocating.allocator;

        var accum = AccumBuffer.init(std.math.maxInt(usize));
        defer accum.deinit(gpa);

        var consumed: u64 = 0;

        while (true) {
            const control = try reader.takeByte();
            consumed += 1;

            switch (control) {
                // End of stream.
                0 => break,
                // Uncompressed chunk; 1 also resets the dictionary, 2 does not.
                1, 2 => consumed += try parseUncompressed(reader, allocating, &accum, control == 1),
                // Anything else must be an LZMA chunk (validated by parseLzma).
                else => consumed += try d.parseLzma(reader, allocating, &accum, control),
            }
        }

        try accum.finish(&allocating.writer);
        return consumed;
    }

    /// Decode one LZMA chunk whose control byte is `status`. Returns the
    /// number of bytes consumed after the control byte: size fields, optional
    /// properties byte, and the range-coded payload.
    fn parseLzma(
        d: *Decode,
        reader: *Reader,
        allocating: *Writer.Allocating,
        accum: *AccumBuffer,
        status: u8,
    ) !u64 {
        if (status & 0x80 == 0) return error.CorruptInput;

        // Bits 5-6 of the control byte select how much is reset:
        // 0 = nothing, 1 = state, 2 = state + props, 3 = state + props + dict.
        const reset_level: u2 = @truncate(status >> 5);
        const reset_dict = reset_level == 3;
        const reset_state = reset_level != 0;
        const reset_props = reset_level >= 2;

        var consumed: u64 = 0;

        // Uncompressed size minus one: the low 5 control bits are bits 16-20,
        // followed by a big-endian u16 for bits 0-15.
        const size_hi: u64 = status & 0x1F;
        const size_lo: u64 = try reader.takeInt(u16, .big);
        consumed += 2;
        const unpacked_size = ((size_hi << 16) | size_lo) + 1;

        // Compressed size minus one, as a big-endian u16.
        const packed_size = @as(u17, try reader.takeInt(u16, .big)) + 1;
        consumed += 2;

        if (reset_dict) try accum.reset(&allocating.writer);

        const ld = &d.lzma_decode;

        if (reset_state) {
            var props = ld.properties;

            if (reset_props) {
                // One byte packs (pb * 5 + lp) * 9 + lc.
                const byte = try reader.takeByte();
                consumed += 1;
                if (byte >= 9 * 5 * 5) return error.CorruptInput;

                const lc: u4 = @intCast(byte % 9);
                const lp: u3 = @intCast((byte / 9) % 5);
                const pb: u3 = @intCast((byte / 9) / 5);
                if (lc + lp > 4) return error.CorruptInput;

                props = .{ .lc = lc, .lp = lp, .pb = pb };
            }

            try ld.resetState(allocating.allocator, props);
        }

        const target_len = accum.len + unpacked_size;
        const header_bytes = consumed;
        var range_decoder = try lzma.RangeDecoder.initCounting(reader, &consumed);

        while (accum.len < target_len) {
            switch (try ld.process(reader, allocating, accum, &range_decoder, &consumed)) {
                .more => {},
                .finished => break,
            }
        }

        if (accum.len != target_len) return error.DecompressedSizeMismatch;
        if (consumed - header_bytes != packed_size) return error.CompressedSizeMismatch;

        return consumed;
    }

    /// Copy one uncompressed chunk into `accum`, first flushing and clearing
    /// the dictionary when `reset_dict` is set. Returns the bytes consumed
    /// after the control byte: the 2-byte size field plus the payload.
    fn parseUncompressed(
        reader: *Reader,
        allocating: *Writer.Allocating,
        accum: *AccumBuffer,
        reset_dict: bool,
    ) !usize {
        const size_field = try reader.takeInt(u16, .big);
        const unpacked_size = @as(u17, size_field) + 1;

        if (reset_dict) try accum.reset(&allocating.writer);

        const gpa = allocating.allocator;

        var remaining: usize = unpacked_size;
        while (remaining > 0) : (remaining -= 1) {
            const byte = try reader.takeByte();
            try accum.appendByte(gpa, byte);
        }

        return 2 + unpacked_size;
    }
};
265
test "decompress hello world stream" {
    const gpa = std.testing.allocator;

    // Stream layout:
    // - 0x01 = uncompressed chunk with dictionary reset, size field 5 (6 bytes): "Hello\n"
    // - 0x02 = uncompressed chunk without reset, size field 6 (7 bytes): "World!\n"
    // - 0x00 = end of stream
    const input = [_]u8{
        0x01, 0x00, 0x05, 'H', 'e', 'l', 'l', 'o', '\n',
        0x02, 0x00, 0x06, 'W', 'o', 'r', 'l', 'd', '!', '\n',
        0x00,
    };

    var decoder = try Decode.init(gpa);
    defer decoder.deinit(gpa);

    var in: std.Io.Reader = .fixed(&input);
    var out: std.Io.Writer.Allocating = .init(gpa);
    defer out.deinit();

    const consumed = try decoder.decompress(&in, &out);

    try std.testing.expectEqual(input.len, consumed);
    try std.testing.expectEqualStrings("Hello\nWorld!\n", out.written());
}