Commit 4dfca01de4

Jacob Young <jacobly0@users.noreply.github.com>
2024-01-29 14:12:19
gzip: implement compression
1 parent 27d2d8e
lib/std/compress/deflate/compressor.zig
@@ -733,7 +733,7 @@ pub fn Compressor(comptime WriterType: anytype) type {
         }
 
         /// Writes the compressed form of `input` to the underlying writer.
-        pub fn write(self: *Self, input: []const u8) !usize {
+        pub fn write(self: *Self, input: []const u8) Error!usize {
             var buf = input;
 
             // writes data to hm_bw, which will eventually write the
@@ -756,7 +756,7 @@ pub fn Compressor(comptime WriterType: anytype) type {
         /// If the underlying writer returns an error, `flush()` returns that error.
         ///
         /// In the terminology of the zlib library, Flush is equivalent to Z_SYNC_FLUSH.
-        pub fn flush(self: *Self) !void {
+        pub fn flush(self: *Self) Error!void {
             self.sync = true;
             try self.step();
             try self.hm_bw.writeStoredHeader(0, false);
@@ -956,7 +956,7 @@ pub fn Compressor(comptime WriterType: anytype) type {
         }
 
         /// Writes any pending data to the underlying writer.
-        pub fn close(self: *Self) !void {
+        pub fn close(self: *Self) Error!void {
             self.sync = true;
             try self.step();
             try self.hm_bw.writeStoredHeader(0, true);
lib/std/compress/deflate/compressor_test.zig
@@ -86,7 +86,7 @@ fn testSync(level: deflate.Compression, input: []const u8) !void {
             read = try decomp.reader().readAll(&final);
             try testing.expectEqual(@as(usize, 0), read); // expect ended stream to return 0 bytes
 
-            _ = decomp.close();
+            try decomp.close();
         }
     }
 
@@ -102,7 +102,7 @@ fn testSync(level: deflate.Compression, input: []const u8) !void {
     defer testing.allocator.free(decompressed);
 
     _ = try decomp.reader().readAll(decompressed);
-    _ = decomp.close();
+    try decomp.close();
 
     try testing.expectEqualSlices(u8, input, decompressed);
 }
@@ -477,7 +477,7 @@ test "inflate reset" {
         .readAllAlloc(testing.allocator, math.maxInt(usize));
     defer testing.allocator.free(decompressed_1);
 
-    _ = decomp.close();
+    try decomp.close();
 
     try testing.expectEqualSlices(u8, strings[0], decompressed_0);
     try testing.expectEqualSlices(u8, strings[1], decompressed_1);
@@ -524,7 +524,7 @@ test "inflate reset dictionary" {
         .readAllAlloc(testing.allocator, math.maxInt(usize));
     defer testing.allocator.free(decompressed_1);
 
-    _ = decomp.close();
+    try decomp.close();
 
     try testing.expectEqualSlices(u8, strings[0], decompressed_0);
     try testing.expectEqualSlices(u8, strings[1], decompressed_1);
lib/std/compress/deflate/decompressor.zig
@@ -477,11 +477,10 @@ pub fn Decompressor(comptime ReaderType: type) type {
             }
         }
 
-        pub fn close(self: *Self) ?Error {
-            if (self.err == @as(?Error, error.EndOfStreamWithNoError)) {
-                return null;
+        pub fn close(self: *Self) Error!void {
+            if (self.err) |err| {
+                if (err != error.EndOfStreamWithNoError) return err;
             }
-            return self.err;
         }
 
         // RFC 1951 section 3.2.7.
@@ -880,7 +879,7 @@ pub fn Decompressor(comptime ReaderType: type) type {
 
         /// Replaces the inner reader and dictionary with new_reader and new_dict.
         /// new_reader must be of the same type as the reader being replaced.
-        pub fn reset(s: *Self, new_reader: ReaderType, new_dict: ?[]const u8) !void {
+        pub fn reset(s: *Self, new_reader: ReaderType, new_dict: ?[]const u8) Error!void {
             s.inner_reader = new_reader;
             s.step = nextBlock;
             s.err = null;
@@ -920,9 +919,7 @@ test "confirm decompressor resets" {
         const buf = try decomp.reader().readAllAlloc(std.testing.allocator, 1024 * 100);
         defer std.testing.allocator.free(buf);
 
-        if (decomp.close()) |err| {
-            return err;
-        }
+        try decomp.close();
 
         try decomp.reset(stream.reader(), null);
     }
lib/std/compress/deflate/deflate_fast_test.zig
@@ -83,7 +83,7 @@ test "best speed" {
                 defer decomp.deinit();
 
                 const read = try decomp.reader().readAll(decompressed);
-                _ = decomp.close();
+                try decomp.close();
 
                 try testing.expectEqual(want.items.len, read);
                 try testing.expectEqualSlices(u8, want.items, decompressed);
@@ -150,7 +150,7 @@ test "best speed max match offset" {
                 var decomp = try inflate.decompressor(testing.allocator, fib.reader(), null);
                 defer decomp.deinit();
                 const read = try decomp.reader().readAll(decompressed);
-                _ = decomp.close();
+                try decomp.close();
 
                 try testing.expectEqual(src.len, read);
                 try testing.expectEqualSlices(u8, src, decompressed);
lib/std/compress/deflate/huffman_bit_writer.zig
@@ -124,7 +124,8 @@ pub fn HuffmanBitWriter(comptime WriterType: type) type {
             if (self.err) {
                 return;
             }
-            self.bytes_written += try self.inner_writer.write(b);
+            try self.inner_writer.writeAll(b);
+            self.bytes_written += b.len;
         }
 
         fn writeBits(self: *Self, b: u32, nb: u32) Error!void {
lib/std/compress/testdata/rfc1952.txt.gz
Binary file
lib/std/compress/gzip.zig
@@ -1,5 +1,5 @@
 //
-// Decompressor for GZIP data streams (RFC1952)
+// Compressor/Decompressor for GZIP data streams (RFC1952)
 
 const std = @import("../std.zig");
 const io = std.io;
@@ -8,6 +8,8 @@ const testing = std.testing;
 const mem = std.mem;
 const deflate = std.compress.deflate;
 
+const magic = &[2]u8{ 0x1f, 0x8b };
+
 // Flags for the FLG field in the header
 const FTEXT = 1 << 0;
 const FHCRC = 1 << 1;
@@ -17,6 +19,14 @@ const FCOMMENT = 1 << 4;
 
 const max_string_len = 1024;
 
+pub const Header = struct {
+    extra: ?[]const u8 = null,
+    filename: ?[]const u8 = null,
+    comment: ?[]const u8 = null,
+    modification_time: u32 = 0,
+    operating_system: u8 = 255,
+};
+
 pub fn Decompress(comptime ReaderType: type) type {
     return struct {
         const Self = @This();
@@ -30,25 +40,19 @@ pub fn Decompress(comptime ReaderType: type) type {
         inflater: deflate.Decompressor(ReaderType),
         in_reader: ReaderType,
         hasher: std.hash.Crc32,
-        read_amt: usize,
-
-        info: struct {
-            extra: ?[]const u8,
-            filename: ?[]const u8,
-            comment: ?[]const u8,
-            modification_time: u32,
-            operating_system: u8,
-        },
+        read_amt: u32,
+
+        info: Header,
 
-        fn init(allocator: mem.Allocator, source: ReaderType) !Self {
-            var hasher = std.compress.hashedReader(source, std.hash.Crc32.init());
+        fn init(allocator: mem.Allocator, in_reader: ReaderType) !Self {
+            var hasher = std.compress.hashedReader(in_reader, std.hash.Crc32.init());
             const hashed_reader = hasher.reader();
 
             // gzip header format is specified in RFC1952
             const header = try hashed_reader.readBytesNoEof(10);
 
             // Check the ID1/ID2 fields
-            if (header[0] != 0x1f or header[1] != 0x8b)
+            if (!std.mem.eql(u8, header[0..2], magic))
                 return error.BadHeader;
 
             const CM = header[2];
@@ -88,15 +92,15 @@ pub fn Decompress(comptime ReaderType: type) type {
             errdefer if (comment) |p| allocator.free(p);
 
             if (FLG & FHCRC != 0) {
-                const hash = try source.readInt(u16, .little);
+                const hash = try in_reader.readInt(u16, .little);
                 if (hash != @as(u16, @truncate(hasher.hasher.final())))
                     return error.WrongChecksum;
             }
 
-            return Self{
+            return .{
                 .allocator = allocator,
-                .inflater = try deflate.decompressor(allocator, source, null),
-                .in_reader = source,
+                .inflater = try deflate.decompressor(allocator, in_reader, null),
+                .in_reader = in_reader,
                 .hasher = std.hash.Crc32.init(),
                 .info = .{
                     .filename = filename,
@@ -119,7 +123,7 @@ pub fn Decompress(comptime ReaderType: type) type {
                 self.allocator.free(comment);
         }
 
-        // Implements the io.Reader interface
+        /// Implements the io.Reader interface
         pub fn read(self: *Self, buffer: []u8) Error!usize {
             if (buffer.len == 0)
                 return 0;
@@ -128,10 +132,12 @@ pub fn Decompress(comptime ReaderType: type) type {
             const r = try self.inflater.read(buffer);
             if (r != 0) {
                 self.hasher.update(buffer[0..r]);
-                self.read_amt += r;
+                self.read_amt +%= @truncate(r);
                 return r;
             }
 
+            try self.inflater.close();
+
             // We've reached the end of stream, check if the checksum matches
             const hash = try self.in_reader.readInt(u32, .little);
             if (hash != self.hasher.final())
@@ -139,7 +145,7 @@ pub fn Decompress(comptime ReaderType: type) type {
 
             // The ISIZE field is the size of the uncompressed input modulo 2^32
             const input_size = try self.in_reader.readInt(u32, .little);
-            if (self.read_amt & 0xffffffff != input_size)
+            if (self.read_amt != input_size)
                 return error.CorruptedData;
 
             return 0;
@@ -155,7 +161,117 @@ pub fn decompress(allocator: mem.Allocator, reader: anytype) !Decompress(@TypeOf
     return Decompress(@TypeOf(reader)).init(allocator, reader);
 }
 
-fn testReader(data: []const u8, comptime expected: []const u8) !void {
+pub const CompressOptions = struct {
+    header: Header = .{},
+    hash_header: bool = true,
+    level: deflate.Compression = .default_compression,
+};
+
+pub fn Compress(comptime WriterType: type) type {
+    return struct {
+        const Self = @This();
+
+        pub const Error = WriterType.Error ||
+            deflate.Compressor(WriterType).Error;
+        pub const Writer = io.Writer(*Self, Error, write);
+
+        allocator: mem.Allocator,
+        deflater: deflate.Compressor(WriterType),
+        out_writer: WriterType,
+        hasher: std.hash.Crc32,
+        write_amt: u32,
+
+        fn init(allocator: mem.Allocator, out_writer: WriterType, options: CompressOptions) !Self {
+            var hasher = std.compress.hashedWriter(out_writer, std.hash.Crc32.init());
+            const hashed_writer = hasher.writer();
+
+            // ID1/ID2
+            try hashed_writer.writeAll(magic);
+            // CM
+            try hashed_writer.writeByte(8);
+            // Flags
+            try hashed_writer.writeByte(
+                @as(u8, if (options.hash_header) FHCRC else 0) |
+                    @as(u8, if (options.header.extra) |_| FEXTRA else 0) |
+                    @as(u8, if (options.header.filename) |_| FNAME else 0) |
+                    @as(u8, if (options.header.comment) |_| FCOMMENT else 0),
+            );
+            // Modification time
+            try hashed_writer.writeInt(u32, options.header.modification_time, .little);
+            // Extra flags
+            try hashed_writer.writeByte(0);
+            // Operating system
+            try hashed_writer.writeByte(options.header.operating_system);
+
+            if (options.header.extra) |extra| {
+                try hashed_writer.writeInt(u16, @intCast(extra.len), .little);
+                try hashed_writer.writeAll(extra);
+            }
+
+            if (options.header.filename) |filename| {
+                try hashed_writer.writeAll(filename);
+                try hashed_writer.writeByte(0);
+            }
+
+            if (options.header.comment) |comment| {
+                try hashed_writer.writeAll(comment);
+                try hashed_writer.writeByte(0);
+            }
+
+            if (options.hash_header) {
+                try out_writer.writeInt(
+                    u16,
+                    @truncate(hasher.hasher.final()),
+                    .little,
+                );
+            }
+
+            return .{
+                .allocator = allocator,
+                .deflater = try deflate.compressor(allocator, out_writer, .{ .level = options.level }),
+                .out_writer = out_writer,
+                .hasher = std.hash.Crc32.init(),
+                .write_amt = 0,
+            };
+        }
+
+        pub fn deinit(self: *Self) void {
+            self.deflater.deinit();
+        }
+
+        /// Implements the io.Writer interface
+        pub fn write(self: *Self, buffer: []const u8) Error!usize {
+            if (buffer.len == 0)
+                return 0;
+
+            // Write to the compressed stream and update the computed checksum
+            const r = try self.deflater.write(buffer);
+            self.hasher.update(buffer[0..r]);
+            self.write_amt +%= @truncate(r);
+            return r;
+        }
+
+        pub fn writer(self: *Self) Writer {
+            return .{ .context = self };
+        }
+
+        pub fn flush(self: *Self) Error!void {
+            try self.deflater.flush();
+        }
+
+        pub fn close(self: *Self) Error!void {
+            try self.deflater.close();
+            try self.out_writer.writeInt(u32, self.hasher.final(), .little);
+            try self.out_writer.writeInt(u32, self.write_amt, .little);
+        }
+    };
+}
+
+pub fn compress(allocator: mem.Allocator, writer: anytype, options: CompressOptions) !Compress(@TypeOf(writer)) {
+    return Compress(@TypeOf(writer)).init(allocator, writer, options);
+}
+
+fn testReader(expected: []const u8, data: []const u8) !void {
     var in_stream = io.fixedBufferStream(data);
 
     var gzip_stream = try decompress(testing.allocator, in_stream.reader());
@@ -169,70 +285,91 @@ fn testReader(data: []const u8, comptime expected: []const u8) !void {
     try testing.expectEqualSlices(u8, expected, buf);
 }
 
+fn testWriter(expected: []const u8, data: []const u8, options: CompressOptions) !void {
+    var actual = std.ArrayList(u8).init(testing.allocator);
+    defer actual.deinit();
+
+    var gzip_stream = try compress(testing.allocator, actual.writer(), options);
+    defer gzip_stream.deinit();
+
+    // Write and compress the whole file
+    try gzip_stream.writer().writeAll(data);
+    try gzip_stream.close();
+
+    // Check against the reference
+    try testing.expectEqualSlices(u8, expected, actual.items);
+}
+
 // All the test cases are obtained by compressing the RFC1952 text
 //
 // https://tools.ietf.org/rfc/rfc1952.txt length=25037 bytes
 // SHA256=164ef0897b4cbec63abf1b57f069f3599bd0fb7c72c2a4dee21bd7e03ec9af67
 test "compressed data" {
-    try testReader(
-        @embedFile("testdata/rfc1952.txt.gz"),
-        @embedFile("testdata/rfc1952.txt"),
-    );
+    const plain = @embedFile("testdata/rfc1952.txt");
+    const compressed = @embedFile("testdata/rfc1952.txt.gz");
+    try testReader(plain, compressed);
+    try testWriter(compressed, plain, .{
+        .header = .{
+            .filename = "rfc1952.txt",
+            .modification_time = 1706533053,
+            .operating_system = 3,
+        },
+    });
 }
 
 test "sanity checks" {
     // Truncated header
     try testing.expectError(
         error.EndOfStream,
-        testReader(&[_]u8{ 0x1f, 0x8B }, ""),
+        testReader(undefined, &[_]u8{ 0x1f, 0x8B }),
     );
     // Wrong CM
     try testing.expectError(
         error.InvalidCompression,
-        testReader(&[_]u8{
+        testReader(undefined, &[_]u8{
             0x1f, 0x8b, 0x09, 0x00, 0x00, 0x00, 0x00, 0x00,
             0x00, 0x03,
-        }, ""),
+        }),
     );
     // Wrong checksum
     try testing.expectError(
         error.WrongChecksum,
-        testReader(&[_]u8{
+        testReader(undefined, &[_]u8{
             0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00,
             0x00, 0x03, 0x03, 0x00, 0x00, 0x00, 0x00, 0x01,
             0x00, 0x00, 0x00, 0x00,
-        }, ""),
+        }),
     );
     // Truncated checksum
     try testing.expectError(
         error.EndOfStream,
-        testReader(&[_]u8{
+        testReader(undefined, &[_]u8{
             0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00,
             0x00, 0x03, 0x03, 0x00, 0x00, 0x00, 0x00,
-        }, ""),
+        }),
     );
     // Wrong initial size
     try testing.expectError(
         error.CorruptedData,
-        testReader(&[_]u8{
+        testReader(undefined, &[_]u8{
             0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00,
             0x00, 0x03, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00,
             0x00, 0x00, 0x00, 0x01,
-        }, ""),
+        }),
     );
     // Truncated initial size field
     try testing.expectError(
         error.EndOfStream,
-        testReader(&[_]u8{
+        testReader(undefined, &[_]u8{
             0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00,
             0x00, 0x03, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00,
             0x00, 0x00, 0x00,
-        }, ""),
+        }),
     );
 }
 
 test "header checksum" {
-    try testReader(&[_]u8{
+    try testReader("", &[_]u8{
         // GZIP header
         0x1f, 0x8b, 0x08, 0x12, 0x00, 0x09, 0x6e, 0x88, 0x00, 0xff, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x00,
 
@@ -241,5 +378,5 @@ test "header checksum" {
 
         // GZIP data
         0x01, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
-    }, "");
+    });
 }
lib/std/compress.zig
@@ -21,7 +21,7 @@ pub fn HashedReader(
 
         pub fn read(self: *@This(), buf: []u8) Error!usize {
             const amt = try self.child_reader.read(buf);
-            self.hasher.update(buf);
+            self.hasher.update(buf[0..amt]);
             return amt;
         }
 
@@ -38,6 +38,36 @@ pub fn hashedReader(
     return .{ .child_reader = reader, .hasher = hasher };
 }
 
+pub fn HashedWriter(
+    comptime WriterType: anytype,
+    comptime HasherType: anytype,
+) type {
+    return struct {
+        child_writer: WriterType,
+        hasher: HasherType,
+
+        pub const Error = WriterType.Error;
+        pub const Writer = std.io.Writer(*@This(), Error, write);
+
+        pub fn write(self: *@This(), buf: []const u8) Error!usize {
+            const amt = try self.child_writer.write(buf);
+            self.hasher.update(buf[0..amt]);
+            return amt;
+        }
+
+        pub fn writer(self: *@This()) Writer {
+            return .{ .context = self };
+        }
+    };
+}
+
+pub fn hashedWriter(
+    writer: anytype,
+    hasher: anytype,
+) HashedWriter(@TypeOf(writer), @TypeOf(hasher)) {
+    return .{ .child_writer = writer, .hasher = hasher };
+}
+
 test {
     _ = deflate;
     _ = gzip;