Commit 722e066173
Changed files (4)
lib
std
compress
lib/std/compress/xz/Decompress.zig
@@ -8,6 +8,7 @@ const Sha256 = std.crypto.hash.sha2.Sha256;
const lzma2 = std.compress.lzma2;
const Writer = std.Io.Writer;
const Reader = std.Io.Reader;
+const assert = std.debug.assert;
/// Underlying compressed data stream to pull bytes from.
input: *Reader,
@@ -28,6 +29,7 @@ pub const Error = error{
Overflow,
InvalidRangeCode,
DecompressedSizeMismatch,
+ CompressedSizeMismatch,
};
pub const Check = enum(u4) {
@@ -62,10 +64,10 @@ pub fn init(
if (!std.mem.eql(u8, magic, &.{ 0xFD, '7', 'z', 'X', 'Z', 0x00 }))
return error.NotXzStream;
- const actual_hash = Crc32.hash(try input.peek(@sizeOf(StreamFlags)));
+ const computed_checksum = Crc32.hash(try input.peek(@sizeOf(StreamFlags)));
const stream_flags = input.takeStruct(StreamFlags, .little) catch unreachable;
const stored_hash = try input.takeInt(u32, .little);
- if (actual_hash != stored_hash) return error.WrongChecksum;
+ if (computed_checksum != stored_hash) return error.WrongChecksum;
return .{
.input = input,
@@ -129,6 +131,7 @@ fn readIndirect(r: *Reader) Reader.Error!usize {
r.end = allocating.writer.end;
}
+ if (d.err != null) return error.ReadFailed;
if (d.block_count == std.math.maxInt(usize)) return error.EndOfStream;
readBlock(input, &allocating) catch |err| switch (err) {
@@ -137,7 +140,10 @@ fn readIndirect(r: *Reader) Reader.Error!usize {
return error.ReadFailed;
},
error.SuccessfulEndOfStream => {
- finish(d);
+ finish(d) catch |finish_err| {
+ d.err = finish_err;
+ return error.ReadFailed;
+ };
d.block_count = std.math.maxInt(usize);
return error.EndOfStream;
},
@@ -184,7 +190,7 @@ fn readBlock(input: *Reader, allocating: *Writer.Allocating) !void {
var packed_size: ?u64 = null;
var unpacked_size: ?u64 = null;
- {
+ const header_size = h: {
// Read the block header via peeking so that we can hash the whole thing too.
const first_byte: usize = try input.peekByte();
if (first_byte == 0) return error.SuccessfulEndOfStream;
@@ -223,95 +229,92 @@ fn readBlock(input: *Reader, allocating: *Writer.Allocating) !void {
const actual_header_size = input.seek - header_seek_start;
if (actual_header_size > declared_header_size) return error.CorruptInput;
- var remaining_bytes = declared_header_size - actual_header_size;
- while (remaining_bytes != 0) {
+ const remaining_bytes = declared_header_size - actual_header_size;
+ for (0..remaining_bytes) |_| {
if (try input.takeByte() != 0) return error.CorruptInput;
- remaining_bytes -= 1;
}
const header_slice = input.buffer[header_seek_start..][0..declared_header_size];
- const actual_hash = Crc32.hash(header_slice);
- const declared_hash = try input.takeInt(u32, .little);
- if (actual_hash != declared_hash) return error.WrongChecksum;
- }
+ const computed_checksum = Crc32.hash(header_slice);
+ const declared_checksum = try input.takeInt(u32, .little);
+ if (computed_checksum != declared_checksum) return error.WrongChecksum;
+ break :h declared_header_size;
+ };
// Compressed Data
var lzma2_decode = try lzma2.Decode.init(allocating.allocator);
+ defer lzma2_decode.deinit(allocating.allocator);
const before_size = allocating.writer.end;
- try lzma2_decode.decompress(input, allocating);
+ const packed_bytes_read = try lzma2_decode.decompress(input, allocating);
const unpacked_bytes = allocating.writer.end - before_size;
- // TODO restore this check
- //if (packed_size) |s| {
- // if (s != packed_counter.bytes_read)
- // return error.CorruptInput;
- //}
+ if (packed_size) |s| {
+ if (s != packed_bytes_read) return error.CorruptInput;
+ }
if (unpacked_size) |s| {
if (s != unpacked_bytes) return error.CorruptInput;
}
// Block Padding
- if (true) @panic("TODO account for block padding");
- //while (block_counter.bytes_read % 4 != 0) {
- // if (try block_reader.takeByte() != 0)
- // return error.CorruptInput;
- //}
-
+ const block_counter = header_size + packed_bytes_read;
+ const padding = (4 - (block_counter % 4)) % 4;
+ for (0..padding) |_| {
+ if (try input.takeByte() != 0) return error.CorruptInput;
+ }
}
-fn finish(d: *Decompress) void {
- _ = d;
- @panic("TODO");
- //const input = d.input;
- //const index_size = blk: {
- // const record_count = try input.takeLeb128(u64);
- // if (record_count != d.block_decode.block_count)
- // return error.CorruptInput;
-
- // var i: usize = 0;
- // while (i < record_count) : (i += 1) {
- // // TODO: validate records
- // _ = try std.leb.readUleb128(u64, counting_reader);
- // _ = try std.leb.readUleb128(u64, counting_reader);
- // }
-
- // while (counter.bytes_read % 4 != 0) {
- // if (try counting_reader.takeByte() != 0)
- // return error.CorruptInput;
- // }
-
- // const hash_a = hasher.hasher.final();
- // const hash_b = try counting_reader.takeInt(u32, .little);
- // if (hash_a != hash_b)
- // return error.WrongChecksum;
-
- // break :blk counter.bytes_read;
- //};
-
- //const hash_a = try d.in_reader.takeInt(u32, .little);
+fn finish(d: *Decompress) !void {
+ const input = d.input;
+ const index_size = blk: {
+ // Assume that we already peeked a zero in readBlock().
+ assert(input.buffered()[0] == 0);
+ var input_counter: u64 = 1;
+ var checksum: Crc32 = .init();
+ checksum.update(&.{0});
+ input.toss(1);
- //const hash_b = blk: {
- // var hasher = hashedReader(d.in_reader, Crc32.init());
- // const hashed_reader = hasher.reader();
+ const record_count = try countLeb128(input, u64, &input_counter, &checksum);
+ if (record_count != d.block_count)
+ return error.CorruptInput;
- // const backward_size = (@as(u64, try hashed_reader.takeInt(u32, .little)) + 1) * 4;
- // if (backward_size != index_size)
- // return error.CorruptInput;
+ for (0..record_count) |_| {
+ // TODO: validate records
+ _ = try countLeb128(input, u64, &input_counter, &checksum);
+ _ = try countLeb128(input, u64, &input_counter, &checksum);
+ }
- // var check: Check = undefined;
- // try readStreamFlags(hashed_reader, &check);
+ const padding_len = (4 - (input_counter % 4)) % 4;
+ const padding = try input.take(padding_len);
+ for (padding) |byte| {
+ if (byte != 0) return error.CorruptInput;
+ }
+ checksum.update(padding);
- // break :blk hasher.hasher.final();
- //};
+ const declared_checksum = try input.takeInt(u32, .little);
+ const computed_checksum = checksum.final();
+ if (computed_checksum != declared_checksum) return error.WrongChecksum;
- //if (hash_a != hash_b)
- // return error.WrongChecksum;
+ break :blk input_counter + padding.len + 4;
+ };
- //const magic = try d.in_reader.takeBytesNoEof(2);
- //if (!std.mem.eql(u8, &magic, &.{ 'Y', 'Z' }))
- // return error.CorruptInput;
+ const declared_checksum = try input.takeInt(u32, .little);
+ const computed_checksum = Crc32.hash(try input.peek(4 + @sizeOf(StreamFlags)));
+ if (declared_checksum != computed_checksum) return error.WrongChecksum;
+ const backward_size = (@as(u64, try input.takeInt(u32, .little)) + 1) * 4;
+ if (backward_size != index_size) return error.CorruptInput;
+ input.toss(@sizeOf(StreamFlags));
+ if (!std.mem.eql(u8, try input.takeArray(2), &.{ 'Y', 'Z' }))
+ return error.CorruptInput;
+}
- //return 0;
+fn countLeb128(reader: *Reader, comptime T: type, counter: *u64, hasher: *Crc32) !T {
+ try reader.fill(8);
+ const start = reader.seek;
+ const result = try reader.takeLeb128(T);
+ const read_slice = reader.buffer[start..reader.seek];
+ hasher.update(read_slice);
+ counter.* += read_slice.len;
+ return result;
}
lib/std/compress/xz/test.zig
@@ -22,47 +22,79 @@ fn testReader(data: []const u8, comptime expected: []const u8) !void {
try testing.expectEqualSlices(u8, expected, result);
}
-test "compressed data" {
+test "fixture good-0-empty.xz" {
try testReader(@embedFile("testdata/good-0-empty.xz"), "");
+}
- inline for ([_][]const u8{
- "good-1-check-none.xz",
- "good-1-check-crc32.xz",
- "good-1-check-crc64.xz",
- "good-1-check-sha256.xz",
- "good-2-lzma2.xz",
- "good-1-block_header-1.xz",
- "good-1-block_header-2.xz",
- "good-1-block_header-3.xz",
- }) |filename| {
- try testReader(@embedFile("testdata/" ++ filename),
- \\Hello
- \\World!
- \\
- );
- }
+const hello_world_text =
+ \\Hello
+ \\World!
+ \\
+;
- inline for ([_][]const u8{
- "good-1-lzma2-1.xz",
- "good-1-lzma2-2.xz",
- "good-1-lzma2-3.xz",
- "good-1-lzma2-4.xz",
- }) |filename| {
- try testReader(@embedFile("testdata/" ++ filename),
- \\Lorem ipsum dolor sit amet, consectetur adipisicing
- \\elit, sed do eiusmod tempor incididunt ut
- \\labore et dolore magna aliqua. Ut enim
- \\ad minim veniam, quis nostrud exercitation ullamco
- \\laboris nisi ut aliquip ex ea commodo
- \\consequat. Duis aute irure dolor in reprehenderit
- \\in voluptate velit esse cillum dolore eu
- \\fugiat nulla pariatur. Excepteur sint occaecat cupidatat
- \\non proident, sunt in culpa qui officia
- \\deserunt mollit anim id est laborum.
- \\
- );
- }
+test "fixture good-1-check-none.xz" {
+ try testReader(@embedFile("testdata/good-1-check-none.xz"), hello_world_text);
+}
+
+test "fixture good-1-check-crc32.xz" {
+ try testReader(@embedFile("testdata/good-1-check-crc32.xz"), hello_world_text);
+}
+
+test "fixture good-1-check-crc64.xz" {
+ try testReader(@embedFile("testdata/good-1-check-crc64.xz"), hello_world_text);
+}
+
+test "fixture good-1-check-sha256.xz" {
+ try testReader(@embedFile("testdata/good-1-check-sha256.xz"), hello_world_text);
+}
+
+test "fixture good-2-lzma2.xz" {
+ try testReader(@embedFile("testdata/good-2-lzma2.xz"), hello_world_text);
+}
+
+test "fixture good-1-block_header-1.xz" {
+ try testReader(@embedFile("testdata/good-1-block_header-1.xz"), hello_world_text);
+}
+
+test "fixture good-1-block_header-2.xz" {
+ try testReader(@embedFile("testdata/good-1-block_header-2.xz"), hello_world_text);
+}
+
+test "fixture good-1-block_header-3.xz" {
+ try testReader(@embedFile("testdata/good-1-block_header-3.xz"), hello_world_text);
+}
+
+const lorem_ipsum_text =
+ \\Lorem ipsum dolor sit amet, consectetur adipisicing
+ \\elit, sed do eiusmod tempor incididunt ut
+ \\labore et dolore magna aliqua. Ut enim
+ \\ad minim veniam, quis nostrud exercitation ullamco
+ \\laboris nisi ut aliquip ex ea commodo
+ \\consequat. Duis aute irure dolor in reprehenderit
+ \\in voluptate velit esse cillum dolore eu
+ \\fugiat nulla pariatur. Excepteur sint occaecat cupidatat
+ \\non proident, sunt in culpa qui officia
+ \\deserunt mollit anim id est laborum.
+ \\
+;
+
+test "fixture good-1-lzma2-1.xz" {
+ try testReader(@embedFile("testdata/good-1-lzma2-1.xz"), lorem_ipsum_text);
+}
+
+test "fixture good-1-lzma2-2.xz" {
+ try testReader(@embedFile("testdata/good-1-lzma2-2.xz"), lorem_ipsum_text);
+}
+
+test "fixture good-1-lzma2-3.xz" {
+ try testReader(@embedFile("testdata/good-1-lzma2-3.xz"), lorem_ipsum_text);
+}
+
+test "fixture good-1-lzma2-4.xz" {
+ try testReader(@embedFile("testdata/good-1-lzma2-4.xz"), lorem_ipsum_text);
+}
+test "fixture good-1-lzma2-5.xz" {
try testReader(@embedFile("testdata/good-1-lzma2-5.xz"), "");
}
lib/std/compress/lzma.zig
@@ -12,11 +12,19 @@ pub const RangeDecoder = struct {
code: u32,
pub fn init(reader: *Reader) !RangeDecoder {
+ var counter: u64 = 0;
+ return initCounting(reader, &counter);
+ }
+
+ pub fn initCounting(reader: *Reader, n_read: *u64) !RangeDecoder {
const reserved = try reader.takeByte();
+ n_read.* += 1;
if (reserved != 0) return error.InvalidRangeCode;
+ const code = try reader.takeInt(u32, .big);
+ n_read.* += 4;
return .{
.range = 0xFFFF_FFFF,
- .code = try reader.takeInt(u32, .big),
+ .code = code,
};
}
@@ -24,47 +32,47 @@ pub const RangeDecoder = struct {
return self.code == 0;
}
- fn normalize(self: *RangeDecoder, reader: *Reader) !void {
+ fn normalize(self: *RangeDecoder, reader: *Reader, n_read: *u64) !void {
if (self.range < 0x0100_0000) {
self.range <<= 8;
self.code = (self.code << 8) ^ @as(u32, try reader.takeByte());
+ n_read.* += 1;
}
}
- fn getBit(self: *RangeDecoder, reader: *Reader) !bool {
+ fn getBit(self: *RangeDecoder, reader: *Reader, n_read: *u64) !bool {
self.range >>= 1;
const bit = self.code >= self.range;
- if (bit)
- self.code -= self.range;
+ if (bit) self.code -= self.range;
- try self.normalize(reader);
+ try self.normalize(reader, n_read);
return bit;
}
- pub fn get(self: *RangeDecoder, reader: *Reader, count: usize) !u32 {
+ pub fn get(self: *RangeDecoder, reader: *Reader, count: usize, n_read: *u64) !u32 {
var result: u32 = 0;
- var i: usize = 0;
- while (i < count) : (i += 1)
- result = (result << 1) ^ @intFromBool(try self.getBit(reader));
+ for (0..count) |_| {
+ result = (result << 1) ^ @intFromBool(try self.getBit(reader, n_read));
+ }
return result;
}
- pub fn decodeBit(self: *RangeDecoder, reader: *Reader, prob: *u16) !bool {
+ pub fn decodeBit(self: *RangeDecoder, reader: *Reader, prob: *u16, n_read: *u64) !bool {
const bound = (self.range >> 11) * prob.*;
if (self.code < bound) {
prob.* += (0x800 - prob.*) >> 5;
self.range = bound;
- try self.normalize(reader);
+ try self.normalize(reader, n_read);
return false;
} else {
prob.* -= prob.* >> 5;
self.code -= bound;
self.range -= bound;
- try self.normalize(reader);
+ try self.normalize(reader, n_read);
return true;
}
}
@@ -74,11 +82,12 @@ pub const RangeDecoder = struct {
reader: *Reader,
num_bits: u5,
probs: []u16,
+ n_read: *u64,
) !u32 {
var tmp: u32 = 1;
var i: @TypeOf(num_bits) = 0;
while (i < num_bits) : (i += 1) {
- const bit = try self.decodeBit(reader, &probs[tmp]);
+ const bit = try self.decodeBit(reader, &probs[tmp], n_read);
tmp = (tmp << 1) ^ @intFromBool(bit);
}
return tmp - (@as(u32, 1) << num_bits);
@@ -90,12 +99,13 @@ pub const RangeDecoder = struct {
num_bits: u5,
probs: []u16,
offset: usize,
+ n_read: *u64,
) !u32 {
var result: u32 = 0;
var tmp: usize = 1;
var i: @TypeOf(num_bits) = 0;
while (i < num_bits) : (i += 1) {
- const bit = @intFromBool(try self.decodeBit(reader, &probs[offset + tmp]));
+ const bit = @intFromBool(try self.decodeBit(reader, &probs[offset + tmp], n_read));
tmp = (tmp << 1) ^ bit;
result ^= @as(u32, bit) << i;
}
@@ -177,13 +187,14 @@ pub const Decode = struct {
/// `CircularBuffer` or `std.compress.lzma2.AccumBuffer`.
buffer: anytype,
decoder: *RangeDecoder,
+ n_read: *u64,
) !ProcessingStatus {
const gpa = allocating.allocator;
const writer = &allocating.writer;
const pos_state = buffer.len & ((@as(usize, 1) << self.properties.pb) - 1);
- if (!try decoder.decodeBit(reader, &self.is_match[(self.state << 4) + pos_state])) {
- const byte: u8 = try self.decodeLiteral(reader, buffer, decoder);
+ if (!try decoder.decodeBit(reader, &self.is_match[(self.state << 4) + pos_state], n_read)) {
+ const byte: u8 = try self.decodeLiteral(reader, buffer, decoder, n_read);
try buffer.appendLiteral(gpa, byte, writer);
@@ -197,18 +208,18 @@ pub const Decode = struct {
}
var len: usize = undefined;
- if (try decoder.decodeBit(reader, &self.is_rep[self.state])) {
- if (!try decoder.decodeBit(reader, &self.is_rep_g0[self.state])) {
- if (!try decoder.decodeBit(reader, &self.is_rep_0long[(self.state << 4) + pos_state])) {
+ if (try decoder.decodeBit(reader, &self.is_rep[self.state], n_read)) {
+ if (!try decoder.decodeBit(reader, &self.is_rep_g0[self.state], n_read)) {
+ if (!try decoder.decodeBit(reader, &self.is_rep_0long[(self.state << 4) + pos_state], n_read)) {
self.state = if (self.state < 7) 9 else 11;
const dist = self.rep[0] + 1;
try buffer.appendLz(gpa, 1, dist, writer);
return .more;
}
} else {
- const idx: usize = if (!try decoder.decodeBit(reader, &self.is_rep_g1[self.state]))
+ const idx: usize = if (!try decoder.decodeBit(reader, &self.is_rep_g1[self.state], n_read))
1
- else if (!try decoder.decodeBit(reader, &self.is_rep_g2[self.state]))
+ else if (!try decoder.decodeBit(reader, &self.is_rep_g2[self.state], n_read))
2
else
3;
@@ -220,7 +231,7 @@ pub const Decode = struct {
self.rep[0] = dist;
}
- len = try self.rep_len_decoder.decode(reader, decoder, pos_state);
+ len = try self.rep_len_decoder.decode(reader, decoder, pos_state, n_read);
self.state = if (self.state < 7) 8 else 11;
} else {
@@ -228,11 +239,11 @@ pub const Decode = struct {
self.rep[2] = self.rep[1];
self.rep[1] = self.rep[0];
- len = try self.len_decoder.decode(reader, decoder, pos_state);
+ len = try self.len_decoder.decode(reader, decoder, pos_state, n_read);
self.state = if (self.state < 7) 7 else 10;
- const rep_0 = try self.decodeDistance(reader, decoder, len);
+ const rep_0 = try self.decodeDistance(reader, decoder, len, n_read);
self.rep[0] = rep_0;
if (self.rep[0] == 0xFFFF_FFFF) {
@@ -257,6 +268,7 @@ pub const Decode = struct {
/// `CircularBuffer` or `std.compress.lzma2.AccumBuffer`.
buffer: anytype,
decoder: *RangeDecoder,
+ n_read: *u64,
) !u8 {
const def_prev_byte = 0;
const prev_byte = @as(usize, buffer.lastOr(def_prev_byte));
@@ -275,6 +287,7 @@ pub const Decode = struct {
const bit = @intFromBool(try decoder.decodeBit(
reader,
&probs[((@as(usize, 1) + match_bit) << 8) + result],
+ n_read,
));
result = (result << 1) ^ bit;
if (match_bit != bit) {
@@ -284,10 +297,10 @@ pub const Decode = struct {
}
while (result < 0x100) {
- result = (result << 1) ^ @intFromBool(try decoder.decodeBit(reader, &probs[result]));
+ result = (result << 1) ^ @intFromBool(try decoder.decodeBit(reader, &probs[result], n_read));
}
- return @as(u8, @truncate(result - 0x100));
+ return @truncate(result - 0x100);
}
fn decodeDistance(
@@ -295,12 +308,12 @@ pub const Decode = struct {
reader: *Reader,
decoder: *RangeDecoder,
length: usize,
+ n_read: *u64,
) !usize {
const len_state = if (length > 3) 3 else length;
- const pos_slot = @as(usize, try self.pos_slot_decoder[len_state].parse(reader, decoder));
- if (pos_slot < 4)
- return pos_slot;
+ const pos_slot: usize = try self.pos_slot_decoder[len_state].parse(reader, decoder, n_read);
+ if (pos_slot < 4) return pos_slot;
const num_direct_bits = @as(u5, @intCast((pos_slot >> 1) - 1));
var result = (2 ^ (pos_slot & 1)) << num_direct_bits;
@@ -311,10 +324,11 @@ pub const Decode = struct {
num_direct_bits,
&self.pos_decoders,
result - pos_slot,
+ n_read,
);
} else {
- result += @as(usize, try decoder.get(reader, num_direct_bits - 4)) << 4;
- result += try self.align_decoder.parseReverse(reader, decoder);
+ result += @as(usize, try decoder.get(reader, num_direct_bits - 4, n_read)) << 4;
+ result += try self.align_decoder.parseReverse(reader, decoder, n_read);
}
return result;
@@ -435,16 +449,17 @@ pub const Decode = struct {
return struct {
probs: [1 << num_bits]u16 = @splat(0x400),
- pub fn parse(self: *@This(), reader: *Reader, decoder: *RangeDecoder) !u32 {
- return decoder.parseBitTree(reader, num_bits, &self.probs);
+ pub fn parse(self: *@This(), reader: *Reader, decoder: *RangeDecoder, n_read: *u64) !u32 {
+ return decoder.parseBitTree(reader, num_bits, &self.probs, n_read);
}
pub fn parseReverse(
self: *@This(),
reader: *Reader,
decoder: *RangeDecoder,
+ n_read: *u64,
) !u32 {
- return decoder.parseReverseBitTree(reader, num_bits, &self.probs, 0);
+ return decoder.parseReverseBitTree(reader, num_bits, &self.probs, 0, n_read);
}
pub fn reset(self: *@This()) void {
@@ -465,13 +480,14 @@ pub const Decode = struct {
reader: *Reader,
decoder: *RangeDecoder,
pos_state: usize,
+ n_read: *u64,
) !usize {
- if (!try decoder.decodeBit(reader, &self.choice)) {
- return @as(usize, try self.low_coder[pos_state].parse(reader, decoder));
- } else if (!try decoder.decodeBit(reader, &self.choice2)) {
- return @as(usize, try self.mid_coder[pos_state].parse(reader, decoder)) + 8;
+ if (!try decoder.decodeBit(reader, &self.choice, n_read)) {
+ return @as(usize, try self.low_coder[pos_state].parse(reader, decoder, n_read));
+ } else if (!try decoder.decodeBit(reader, &self.choice2, n_read)) {
+ return @as(usize, try self.mid_coder[pos_state].parse(reader, decoder, n_read)) + 8;
} else {
- return @as(usize, try self.high_coder.parse(reader, decoder)) + 16;
+ return @as(usize, try self.high_coder.parse(reader, decoder, n_read)) + 16;
}
}
@@ -701,7 +717,8 @@ pub const Decompress = struct {
} else if (d.range_decoder.isFinished()) {
break :process_next;
}
- switch (d.decode.process(d.input, &allocating, &d.buffer, &d.range_decoder) catch |err| switch (err) {
+ var n_read: u64 = 0;
+ switch (d.decode.process(d.input, &allocating, &d.buffer, &d.range_decoder, &n_read) catch |err| switch (err) {
error.WriteFailed => {
d.err = error.OutOfMemory;
return error.ReadFailed;
lib/std/compress/lzma2.zig
@@ -116,24 +116,29 @@ pub const Decode = struct {
self.* = undefined;
}
- pub fn decompress(d: *Decode, reader: *Reader, allocating: *Writer.Allocating) !void {
+ /// Returns how many compressed bytes were consumed.
+ pub fn decompress(d: *Decode, reader: *Reader, allocating: *Writer.Allocating) !u64 {
const gpa = allocating.allocator;
var accum = AccumBuffer.init(std.math.maxInt(usize));
defer accum.deinit(gpa);
+ var n_read: u64 = 0;
+
while (true) {
const status = try reader.takeByte();
+ n_read += 1;
switch (status) {
0 => break,
- 1 => try parseUncompressed(reader, allocating, &accum, true),
- 2 => try parseUncompressed(reader, allocating, &accum, false),
- else => try d.parseLzma(reader, allocating, &accum, status),
+ 1 => n_read += try parseUncompressed(reader, allocating, &accum, true),
+ 2 => n_read += try parseUncompressed(reader, allocating, &accum, false),
+ else => n_read += try d.parseLzma(reader, allocating, &accum, status),
}
}
try accum.finish(&allocating.writer);
+ return n_read;
}
fn parseLzma(
@@ -142,7 +147,7 @@ pub const Decode = struct {
allocating: *Writer.Allocating,
accum: *AccumBuffer,
status: u8,
- ) !void {
+ ) !u64 {
if (status & 0x80 == 0) return error.CorruptInput;
const Reset = struct {
@@ -175,15 +180,19 @@ pub const Decode = struct {
else => unreachable,
};
+ var n_read: u64 = 0;
+
const unpacked_size = blk: {
var tmp: u64 = status & 0x1F;
tmp <<= 16;
tmp |= try reader.takeInt(u16, .big);
+ n_read += 2;
break :blk tmp + 1;
};
const packed_size = blk: {
const tmp: u17 = try reader.takeInt(u16, .big);
+ n_read += 2;
break :blk tmp + 1;
};
@@ -196,6 +205,7 @@ pub const Decode = struct {
if (reset.props) {
var props = try reader.takeByte();
+ n_read += 1;
if (props >= 225) {
return error.CorruptInput;
}
@@ -216,23 +226,21 @@ pub const Decode = struct {
try ld.resetState(allocating.allocator, new_props);
}
- var range_decoder = try lzma.RangeDecoder.init(reader);
+ const start_count = n_read;
+ var range_decoder = try lzma.RangeDecoder.initCounting(reader, &n_read);
while (true) {
if (accum.len >= unpacked_size) break;
if (range_decoder.isFinished()) break;
- switch (try ld.process(reader, allocating, accum, &range_decoder)) {
+ switch (try ld.process(reader, allocating, accum, &range_decoder, &n_read)) {
.more => continue,
.finished => break,
}
}
if (accum.len != unpacked_size) return error.DecompressedSizeMismatch;
+ if (n_read - start_count != packed_size) return error.CompressedSizeMismatch;
- // TODO restore this error
- //if (counter.bytes_read != packed_size) {
- // return error.CorruptInput;
- //}
- _ = packed_size;
+ return n_read;
}
fn parseUncompressed(
@@ -240,18 +248,17 @@ pub const Decode = struct {
allocating: *Writer.Allocating,
accum: *AccumBuffer,
reset_dict: bool,
- ) !void {
+ ) !usize {
const unpacked_size = @as(u17, try reader.takeInt(u16, .big)) + 1;
if (reset_dict) try accum.reset(&allocating.writer);
const gpa = allocating.allocator;
- var i = unpacked_size;
- while (i != 0) {
+ for (0..unpacked_size) |_| {
try accum.appendByte(gpa, try reader.takeByte());
- i -= 1;
}
+ return 2 + unpacked_size;
}
};
@@ -268,6 +275,7 @@ test "decompress hello world stream" {
var result: std.Io.Writer.Allocating = .init(gpa);
defer result.deinit();
- try decode.decompress(&stream, &result);
+ const n_read = try decode.decompress(&stream, &result);
+ try std.testing.expectEqual(compressed.len, n_read);
try std.testing.expectEqualStrings(expected, result.written());
}