Commit 622a364715

fn ⌃ ⌥ <70830482+FnControlOption@users.noreply.github.com>
2023-02-05 17:23:51
Implement std.io.Reader for LZMA1
1 parent e03d6c4
Changed files (5)
lib
lib/std/compress/lzma/decode/lzbuffer.zig
@@ -98,6 +98,7 @@ pub const LzAccumBuffer = struct {
 
     pub fn finish(self: *Self, writer: anytype) !void {
         try writer.writeAll(self.buf.items);
+        self.buf.clearRetainingCapacity();
     }
 
     pub fn deinit(self: *Self, allocator: Allocator) void {
@@ -216,6 +217,7 @@ pub const LzCircularBuffer = struct {
     pub fn finish(self: *Self, writer: anytype) !void {
         if (self.cursor > 0) {
             try writer.writeAll(self.buf.items[0..self.cursor]);
+            self.cursor = 0;
         }
     }
 
lib/std/compress/lzma/decode.zig
@@ -280,26 +280,29 @@ pub const DecoderState = struct {
         writer: anytype,
         buffer: anytype,
         decoder: *RangeDecoder,
-    ) !void {
-        while (true) {
+    ) !ProcessingStatus {
+        process_next: {
             if (self.unpacked_size) |unpacked_size| {
                 if (buffer.len >= unpacked_size) {
-                    break;
+                    break :process_next;
                 }
             } else if (decoder.isFinished()) {
-                break;
+                break :process_next;
             }
 
-            if (try self.processNext(allocator, reader, writer, buffer, decoder) == .finished) {
-                break;
+            switch (try self.processNext(allocator, reader, writer, buffer, decoder)) {
+                .continue_ => return .continue_,
+                .finished => break :process_next,
             }
         }
 
-        if (self.unpacked_size) |len| {
-            if (len != buffer.len) {
+        if (self.unpacked_size) |unpacked_size| {
+            if (buffer.len != unpacked_size) {
                 return error.CorruptInput;
             }
         }
+
+        return .finished;
     }
 
     fn decodeLiteral(
@@ -374,36 +377,3 @@ pub const DecoderState = struct {
         return result;
     }
 };
-
-pub const Decoder = struct {
-    params: Params,
-    memlimit: usize,
-    state: DecoderState,
-
-    pub fn init(allocator: Allocator, params: Params, memlimit: ?usize) !Decoder {
-        return Decoder{
-            .params = params,
-            .memlimit = memlimit orelse math.maxInt(usize),
-            .state = try DecoderState.init(allocator, params.properties, params.unpacked_size),
-        };
-    }
-
-    pub fn deinit(self: *Decoder, allocator: Allocator) void {
-        self.state.deinit(allocator);
-        self.* = undefined;
-    }
-
-    pub fn decompress(
-        self: *Decoder,
-        allocator: Allocator,
-        reader: anytype,
-        writer: anytype,
-    ) !void {
-        var buffer = LzCircularBuffer.init(self.params.dict_size, self.memlimit);
-        defer buffer.deinit(allocator);
-
-        var decoder = try RangeDecoder.init(reader);
-        try self.state.process(allocator, reader, writer, &buffer, &decoder);
-        try buffer.finish(writer);
-    }
-};
lib/std/compress/lzma/test.zig
@@ -1,22 +1,24 @@
 const std = @import("../../std.zig");
 const lzma = @import("../lzma.zig");
 
-fn testDecompress(compressed: []const u8, writer: anytype) !void {
+fn testDecompress(compressed: []const u8) ![]u8 {
     const allocator = std.testing.allocator;
     var stream = std.io.fixedBufferStream(compressed);
-    try lzma.decompress(allocator, stream.reader(), writer, .{});
+    var decompressor = try lzma.decompress(allocator, stream.reader());
+    defer decompressor.deinit();
+    const reader = decompressor.reader();
+    return reader.readAllAlloc(allocator, std.math.maxInt(usize));
 }
 
 fn testDecompressEqual(expected: []const u8, compressed: []const u8) !void {
     const allocator = std.testing.allocator;
-    var decomp = std.ArrayList(u8).init(allocator);
-    defer decomp.deinit();
-    try testDecompress(compressed, decomp.writer());
-    try std.testing.expectEqualSlices(u8, expected, decomp.items);
+    const decomp = try testDecompress(compressed);
+    defer allocator.free(decomp);
+    try std.testing.expectEqualSlices(u8, expected, decomp);
 }
 
 fn testDecompressError(expected: anyerror, compressed: []const u8) !void {
-    return std.testing.expectError(expected, testDecompress(compressed, std.io.null_writer));
+    return std.testing.expectError(expected, testDecompress(compressed));
 }
 
 test "LZMA: decompress empty world" {
lib/std/compress/lzma2/decode.zig
@@ -141,7 +141,7 @@ pub const Decoder = struct {
         const counter_reader = counter.reader();
 
         var rangecoder = try RangeDecoder.init(counter_reader);
-        try self.lzma_state.process(allocator, counter_reader, writer, accum, &rangecoder);
+        while (try self.lzma_state.process(allocator, counter_reader, writer, accum, &rangecoder) == .continue_) {}
 
         if (counter.bytes_read != packed_size) {
             return error.CorruptInput;
lib/std/compress/lzma.zig
@@ -1,4 +1,6 @@
 const std = @import("../std.zig");
+const math = std.math;
+const mem = std.mem;
 const Allocator = std.mem.Allocator;
 
 pub const decode = @import("lzma/decode.zig");
@@ -6,13 +8,80 @@ pub const decode = @import("lzma/decode.zig");
 pub fn decompress(
     allocator: Allocator,
     reader: anytype,
-    writer: anytype,
+) !Decompress(@TypeOf(reader)) {
+    return decompressWithOptions(allocator, reader, .{});
+}
+
+pub fn decompressWithOptions(
+    allocator: Allocator,
+    reader: anytype,
     options: decode.Options,
-) !void {
+) !Decompress(@TypeOf(reader)) {
     const params = try decode.Params.readHeader(reader, options);
-    var decoder = try decode.Decoder.init(allocator, params, options.memlimit);
-    defer decoder.deinit(allocator);
-    return decoder.decompress(allocator, reader, writer);
+    return Decompress(@TypeOf(reader)).init(allocator, reader, params, options.memlimit);
+}
+
+pub fn Decompress(comptime ReaderType: type) type {
+    return struct {
+        const Self = @This();
+
+        pub const Error =
+            ReaderType.Error ||
+            Allocator.Error ||
+            error{ CorruptInput, EndOfStream, Overflow };
+
+        pub const Reader = std.io.Reader(*Self, Error, read);
+
+        allocator: Allocator,
+        in_reader: ReaderType,
+        to_read: std.ArrayListUnmanaged(u8),
+
+        buffer: decode.lzbuffer.LzCircularBuffer,
+        decoder: decode.rangecoder.RangeDecoder,
+        state: decode.DecoderState,
+
+        pub fn init(allocator: Allocator, source: ReaderType, params: decode.Params, memlimit: ?usize) !Self {
+            return Self{
+                .allocator = allocator,
+                .in_reader = source,
+                .to_read = .{},
+
+                .buffer = decode.lzbuffer.LzCircularBuffer.init(params.dict_size, memlimit orelse math.maxInt(usize)),
+                .decoder = try decode.rangecoder.RangeDecoder.init(source),
+                .state = try decode.DecoderState.init(allocator, params.properties, params.unpacked_size),
+            };
+        }
+
+        pub fn reader(self: *Self) Reader {
+            return .{ .context = self };
+        }
+
+        pub fn deinit(self: *Self) void {
+            self.to_read.deinit(self.allocator);
+            self.buffer.deinit(self.allocator);
+            self.state.deinit(self.allocator);
+            self.* = undefined;
+        }
+
+        pub fn read(self: *Self, output: []u8) Error!usize {
+            const writer = self.to_read.writer(self.allocator);
+            while (self.to_read.items.len < output.len) {
+                switch (try self.state.process(self.allocator, self.in_reader, writer, &self.buffer, &self.decoder)) {
+                    .continue_ => {},
+                    .finished => {
+                        try self.buffer.finish(writer);
+                        break;
+                    },
+                }
+            }
+            const input = self.to_read.items;
+            const n = math.min(input.len, output.len);
+            mem.copy(u8, output[0..n], input[0..n]);
+            mem.copy(u8, input, input[n..]);
+            self.to_read.shrinkRetainingCapacity(input.len - n);
+            return n;
+        }
+    };
 }
 
 test {