Commit 6bfa7bf197

Igor Anić <igor.anic@gmail.com>
2023-12-01 18:26:31
tar: use scratch buffer for file names
That makes names strings stable during the iteration. Otherwise string buffers can be overwritten while reading file content.
1 parent 6e7a39c
Changed files (1)
lib
lib/std/tar.zig
@@ -66,6 +66,7 @@ pub const Options = struct {
 };
 
 const BLOCK_SIZE = 512;
+const MAX_HEADER_NAME_SIZE = 100 + 1 + 155; // name(100) + separator(1) + prefix(155)
 
 pub const Header = struct {
     bytes: *const [BLOCK_SIZE]u8,
@@ -90,16 +91,14 @@ pub const Header = struct {
     };
 
     /// Includes prefix concatenated, if any.
-    /// Return value may point into Header buffer, or might point into the
-    /// argument buffer.
     /// TODO: check against "../" and other nefarious things
-    pub fn fullFileName(header: Header, buffer: *[std.fs.MAX_PATH_BYTES]u8) ![]const u8 {
+    pub fn fullName(header: Header, buffer: *[MAX_HEADER_NAME_SIZE]u8) ![]const u8 {
         const n = name(header);
-        if (!is_ustar(header))
-            return n;
         const p = prefix(header);
-        if (p.len == 0)
-            return n;
+        if (!is_ustar(header) or p.len == 0) {
+            @memcpy(buffer[0..n.len], n);
+            return buffer[0..n.len];
+        }
         @memcpy(buffer[0..p.len], p);
         buffer[p.len] = '/';
         @memcpy(buffer[p.len + 1 ..][0..n.len], n);
@@ -180,7 +179,7 @@ pub const Header = struct {
     }
 
     // Checks calculated chksum with value of chksum field.
-    // Returns error or chksum value.
+    // Returns error or valid chksum value.
     // Zero value indicates empty block.
     pub fn checkChksum(header: Header) !u64 {
         const field = try header.chksum();
@@ -190,7 +189,7 @@ pub const Header = struct {
     }
 };
 
-// break string on first null char
+// Breaks string on first null char.
 fn nullStr(str: []const u8) []const u8 {
     for (str, 0..) |c, i| {
         if (c == 0) return str[0..i];
@@ -198,14 +197,10 @@ fn nullStr(str: []const u8) []const u8 {
     return str;
 }
 
-// File size rounded to te block boundary.
-inline fn roundedFileSize(file_size: usize) usize {
-    return std.mem.alignForward(usize, file_size, BLOCK_SIZE);
-}
-
 // Number of padding bytes in the last file block.
-inline fn filePadding(file_size: usize) usize {
-    return roundedFileSize(file_size) - file_size;
+inline fn blockPadding(size: usize) usize {
+    const block_rounded = std.mem.alignForward(usize, size, BLOCK_SIZE); // size rounded to te block boundary
+    return block_rounded - size;
 }
 
 fn BufferedReader(comptime ReaderType: type) type {
@@ -217,44 +212,38 @@ fn BufferedReader(comptime ReaderType: type) type {
 
         const Self = @This();
 
-        fn readChunk(self: *Self, count: usize) ![]const u8 {
-            self.ensureCapacity(BLOCK_SIZE * 2);
-            const ask = @min(self.buffer.len - self.end, count -| (self.end - self.start));
-            self.end += try self.unbuffered_reader.readAtLeast(self.buffer[self.end..], ask);
-            return self.buffer[self.start..self.end];
+        // Fills buffer from underlaying reader.
+        fn fillBuffer(self: *Self) !void {
+            self.removeUsed();
+            self.end += try self.unbuffered_reader.read(self.buffer[self.end..]);
         }
 
-        // Returns slice of size count or part of it.
+        // Returns slice of size count or how much fits into buffer.
         pub fn readSlice(self: *Self, count: usize) ![]const u8 {
             if (count <= self.end - self.start) {
-                // fastpath, we have enough bytes in buffer
                 return self.buffer[self.start .. self.start + count];
             }
-
-            const chunk_size = roundedFileSize(count) + BLOCK_SIZE;
-            const temp = try self.readChunk(chunk_size);
-            if (temp.len == 0) return error.UnexpectedEndOfStream;
-            return temp[0..@min(count, temp.len)];
+            try self.fillBuffer();
+            const buf = self.buffer[self.start..self.end];
+            if (buf.len == 0) return error.UnexpectedEndOfStream;
+            return buf[0..@min(count, buf.len)];
         }
 
-        // Returns tar header block, 512 bytes. Before reading advances buffer
-        // for padding of the previous block, to position reader at the start of
-        // new block. After reading advances for block size, to position reader
-        // at the start of the file body.
-        pub fn readBlock(self: *Self, padding: usize) !?[]const u8 {
+        // Returns tar header block, 512 bytes, or null if eof. Before reading
+        // advances buffer for padding of the previous block, to position reader
+        // at the start of new block. After reading advances for block size, to
+        // position reader at the start of the file content.
+        pub fn readHeader(self: *Self, padding: usize) !?[]const u8 {
             try self.skip(padding);
-            const block_bytes = try self.readChunk(BLOCK_SIZE * 2);
-            switch (block_bytes.len) {
-                0 => return null,
-                1...(BLOCK_SIZE - 1) => return error.UnexpectedEndOfStream,
-                else => {},
-            }
+            const buf = self.readSlice(BLOCK_SIZE) catch return null;
+            if (buf.len < BLOCK_SIZE) return error.UnexpectedEndOfStream;
             self.advance(BLOCK_SIZE);
-            return block_bytes[0..BLOCK_SIZE];
+            return buf[0..BLOCK_SIZE];
         }
 
-        // Retruns byte at current position in buffer.
+        // Returns byte at current position in buffer.
         pub fn readByte(self: *@This()) u8 {
+            assert(self.start < self.end);
             return self.buffer[self.start];
         }
 
@@ -275,78 +264,36 @@ fn BufferedReader(comptime ReaderType: type) type {
             }
         }
 
-        inline fn ensureCapacity(self: *Self, count: usize) void {
-            if (self.buffer.len - self.start < count) {
-                const dest_end = self.end - self.start;
-                @memcpy(self.buffer[0..dest_end], self.buffer[self.start..self.end]);
-                self.end = dest_end;
-                self.start = 0;
-            }
+        // Removes used part of the buffer.
+        inline fn removeUsed(self: *Self) void {
+            const dest_end = self.end - self.start;
+            if (self.start == 0 or dest_end > self.start) return;
+            @memcpy(self.buffer[0..dest_end], self.buffer[self.start..self.end]);
+            self.end = dest_end;
+            self.start = 0;
         }
 
-        // Write count bytes to the writer.
+        // Writes count bytes to the writer. Advances reader.
         pub fn write(self: *Self, writer: anytype, count: usize) !void {
-            if (self.read(count)) |buf| {
-                try writer.writeAll(buf);
-                return;
-            }
-            var rdr = self.sliceReader(count);
-            while (try rdr.next()) |slice| {
+            var pos: usize = 0;
+            while (pos < count) {
+                const slice = try self.readSlice(count - pos);
                 try writer.writeAll(slice);
+                self.advance(slice.len);
+                pos += slice.len;
             }
         }
 
-        // Copy dst.len bytes into dst buffer.
+        // Copies dst.len bytes into dst buffer. Advances reader.
         pub fn copy(self: *Self, dst: []u8) ![]const u8 {
-            if (self.read(dst.len)) |buf| {
-                // fastpath we already have enough bytes in buffer
-                @memcpy(dst, buf);
-                return dst;
-            }
-            var rdr = self.sliceReader(dst.len);
             var pos: usize = 0;
-            while (try rdr.next()) |slice| : (pos += slice.len) {
+            while (pos < dst.len) {
+                const slice = try self.readSlice(dst.len - pos);
                 @memcpy(dst[pos .. pos + slice.len], slice);
-            }
-            return dst;
-        }
-
-        // Retruns count bytes from buffer and advances for that number of
-        // bytes. If we don't have that much bytes buffered returns null.
-        fn read(self: *Self, count: usize) ?[]const u8 {
-            if (count <= self.end - self.start) {
-                const buf = self.buffer[self.start .. self.start + count];
-                self.advance(count);
-                return buf;
-            }
-            return null;
-        }
-
-        const SliceReader = struct {
-            size: usize,
-            offset: usize,
-            reader: *Self,
-
-            pub fn next(self: *SliceReader) !?[]const u8 {
-                const remaining_size = self.size - self.offset;
-                if (remaining_size == 0) return null;
-                const slice = try self.reader.readSlice(remaining_size);
                 self.advance(slice.len);
-                return slice;
-            }
-
-            fn advance(self: *SliceReader, len: usize) void {
-                self.offset += len;
-                self.reader.advance(len);
+                pos += slice.len;
             }
-        };
-
-        pub fn sliceReader(self: *Self, size: usize) SliceReader {
-            return .{
-                .size = size,
-                .reader = self,
-                .offset = 0,
-            };
+            return dst;
         }
 
         pub fn paxFileReader(self: *Self, size: usize) PaxFileReader {
@@ -388,9 +335,6 @@ fn BufferedReader(comptime ReaderType: type) type {
             // Caller of the next has to call value in PaxAttribute, to advance
             // reader across value.
             pub fn next(self: *PaxFileReader) !?PaxAttribute {
-                const rdr = self.reader;
-                _ = rdr;
-
                 while (true) {
                     const remaining_size = self.size - self.offset;
                     if (remaining_size == 0) return null;
@@ -433,10 +377,14 @@ fn Iterator(comptime ReaderType: type) type {
     return struct {
         // scratch buffer for file attributes
         scratch: struct {
-            // size: two paths (name and link_name) and size (24 in pax attribute)
+            // size: two paths (name and link_name) and files size bytes (24 in pax attribute)
             buffer: [std.fs.MAX_PATH_BYTES * 2 + 24]u8 = undefined,
             tail: usize = 0,
 
+            name: []const u8 = undefined,
+            link_name: []const u8 = undefined,
+            size: usize = 0,
+
             // Allocate size of the buffer for some attribute.
             fn alloc(self: *@This(), size: usize) ![]u8 {
                 const free_size = self.buffer.len - self.tail;
@@ -447,45 +395,53 @@ fn Iterator(comptime ReaderType: type) type {
                 return self.buffer[head..self.tail];
             }
 
-            // Free whole buffer.
-            fn free(self: *@This()) void {
+            // Reset buffer and all fields.
+            fn reset(self: *@This()) void {
                 self.tail = 0;
+                self.name = self.buffer[0..0];
+                self.link_name = self.buffer[0..0];
+                self.size = 0;
+            }
+
+            fn append(self: *@This(), header: Header) !void {
+                if (self.size == 0) self.size = try header.fileSize();
+                if (self.link_name.len == 0) {
+                    const link_name = header.linkName();
+                    if (link_name.len > 0) {
+                        const buf = try self.alloc(link_name.len);
+                        @memcpy(buf, link_name);
+                        self.link_name = buf;
+                    }
+                }
+                if (self.name.len == 0) {
+                    self.name = try header.fullName((try self.alloc(MAX_HEADER_NAME_SIZE))[0..MAX_HEADER_NAME_SIZE]);
+                }
             }
         } = .{},
 
         reader: BufferedReaderType,
         diagnostics: ?*Options.Diagnostics,
-        padding: usize = 0, // bytes of file padding
+        padding: usize = 0, // bytes of padding to the end of the block
 
         const Self = @This();
 
-        const File = struct {
-            name: []const u8 = &[_]u8{},
-            link_name: []const u8 = &[_]u8{},
-            size: usize = 0,
-            file_type: Header.FileType = .normal,
+        pub const File = struct {
+            name: []const u8, // name of file, symlink or directory
+            link_name: []const u8, // target name of symlink
+            size: usize, // size of the file in bytes
+            file_type: Header.FileType,
+
             reader: *BufferedReaderType,
 
+            // Writes file content to writer.
             pub fn write(self: File, writer: anytype) !void {
                 try self.reader.write(writer, self.size);
             }
 
+            // Skips file content. Advances reader.
             pub fn skip(self: File) !void {
                 try self.reader.skip(self.size);
             }
-
-            fn chksum(self: File) ![16]u8 {
-                var sum = [_]u8{0} ** 16;
-                if (self.size == 0) return sum;
-
-                var rdr = self.reader.sliceReader(self.size);
-                var h = std.crypto.hash.Md5.init(.{});
-                while (try rdr.next()) |slice| {
-                    h.update(slice);
-                }
-                h.final(&sum);
-                return sum;
-            }
         };
 
         // Externally, `next` iterates through the tar archive as if it is a
@@ -495,62 +451,62 @@ fn Iterator(comptime ReaderType: type) type {
         // loop iterates through one or more "header files" until it finds a
         // "normal file".
         pub fn next(self: *Self) !?File {
-            var file: File = .{ .reader = &self.reader };
-            self.scratch.free();
+            self.scratch.reset();
 
-            while (try self.reader.readBlock(self.padding)) |block_bytes| {
+            while (try self.reader.readHeader(self.padding)) |block_bytes| {
                 const header = Header{ .bytes = block_bytes[0..BLOCK_SIZE] };
                 if (try header.checkChksum() == 0) return null; // zero block found
 
                 const file_type = header.fileType();
-                const file_size = try header.fileSize();
-                self.padding = filePadding(file_size);
+                const size: usize = @intCast(try header.fileSize());
+                self.padding = blockPadding(size);
 
                 switch (file_type) {
-                    // file types to retrun from next
+                    // File types to retrun upstream
                     .directory, .normal, .symbolic_link => {
-                        if (file.size == 0) file.size = file_size;
-                        self.padding = filePadding(file.size);
-
-                        if (file.name.len == 0)
-                            file.name = try header.fullFileName((try self.scratch.alloc(std.fs.MAX_PATH_BYTES))[0..std.fs.MAX_PATH_BYTES]);
-                        if (file.link_name.len == 0) file.link_name = header.linkName();
-                        file.file_type = file_type;
+                        try self.scratch.append(header);
+                        const file = File{
+                            .file_type = file_type,
+                            .name = self.scratch.name,
+                            .link_name = self.scratch.link_name,
+                            .size = self.scratch.size,
+                            .reader = &self.reader,
+                        };
+                        self.padding = blockPadding(file.size);
                         return file;
                     },
-                    // prefix header types
+                    // Prefix header types
                     .gnu_long_name => {
-                        file.name = nullStr(try self.reader.copy(try self.scratch.alloc(file_size)));
+                        self.scratch.name = nullStr(try self.reader.copy(try self.scratch.alloc(size)));
                     },
                     .gnu_long_link => {
-                        file.link_name = nullStr(try self.reader.copy(try self.scratch.alloc(file_size)));
+                        self.scratch.link_name = nullStr(try self.reader.copy(try self.scratch.alloc(size)));
                     },
                     .extended_header => {
-                        if (file_size == 0) continue;
-                        // use just last extended header data
-                        self.scratch.free();
-                        file = File{ .reader = &self.reader };
+                        if (size == 0) continue;
+                        // Use just attributes from last extended header.
+                        self.scratch.reset();
 
-                        var rdr = self.reader.paxFileReader(file_size);
+                        var rdr = self.reader.paxFileReader(size);
                         while (try rdr.next()) |attr| {
                             switch (attr.key) {
                                 .path => {
-                                    file.name = try noNull(try attr.value(try self.scratch.alloc(attr.value_len)));
+                                    self.scratch.name = try noNull(try attr.value(try self.scratch.alloc(attr.value_len)));
                                 },
                                 .linkpath => {
-                                    file.link_name = try noNull(try attr.value(try self.scratch.alloc(attr.value_len)));
+                                    self.scratch.link_name = try noNull(try attr.value(try self.scratch.alloc(attr.value_len)));
                                 },
                                 .size => {
-                                    file.size = try std.fmt.parseInt(usize, try attr.value(try self.scratch.alloc(attr.value_len)), 10);
+                                    self.scratch.size = try std.fmt.parseInt(usize, try attr.value(try self.scratch.alloc(attr.value_len)), 10);
                                 },
                             }
                         }
                     },
-                    // ignored header types
+                    // Ignored header type
                     .global_extended_header => {
-                        self.reader.skip(file_size) catch return error.TarHeadersTooBig;
+                        self.reader.skip(size) catch return error.TarHeadersTooBig;
                     },
-                    // unsupported header types
+                    // All other are unsupported header types
                     else => {
                         const d = self.diagnostics orelse return error.TarUnsupportedFileType;
                         try d.errors.append(d.allocator, .{ .unsupported_file_type = .{
@@ -1053,16 +1009,31 @@ test "tar: Go test cases" {
             try std.testing.expectEqualStrings(expected.link_name, actual.link_name);
 
             if (case.chksums.len > i) {
-                var actual_chksum = try actual.chksum();
-                var hex_to_bytes_buffer: [16]u8 = undefined;
-                const expected_chksum = try std.fmt.hexToBytes(&hex_to_bytes_buffer, case.chksums[i]);
-                // std.debug.print("actual chksum: {s}\n", .{std.fmt.fmtSliceHexLower(&actual_chksum)});
-                try std.testing.expectEqualStrings(expected_chksum, &actual_chksum);
+                var md5writer = Md5Writer{};
+                try actual.write(&md5writer);
+                const chksum = md5writer.chksum();
+                // std.debug.print("actual chksum: {s}\n", .{chksum});
+                try std.testing.expectEqualStrings(case.chksums[i], &chksum);
             } else {
                 if (!expected.truncated) try actual.skip(); // skip file content
             }
-            i += 1;
         }
         try std.testing.expectEqual(case.files.len, i);
     }
 }
+
+// used in test to calculate file chksum
+const Md5Writer = struct {
+    h: std.crypto.hash.Md5 = std.crypto.hash.Md5.init(.{}),
+
+    pub fn writeAll(self: *Md5Writer, buf: []const u8) !void {
+        self.h.update(buf);
+    }
+
+    pub fn chksum(self: *Md5Writer) [32]u8 {
+        var s = [_]u8{0} ** 16;
+        self.h.final(&s);
+        return std.fmt.bytesToHex(s, .lower);
+    }
+};
+