Commit 959d227d13

Ryan Liptak <squeek502@hotmail.com>
2024-07-12 09:38:10
ArgIteratorWindows: Reduce allocated memory by parsing the WTF-16 string directly
Before this commit, the WTF-16 command line string would be converted to WTF-8 in `init`, and then a second buffer of the WTF-8 size + 1 would be allocated to store the parsed arguments. The converted WTF-8 command line would then be parsed and the relevant bytes would be copied into the argument buffer before being returned. After this commit, only the WTF-8 size of the WTF-16 string is calculated (without conversion) which is then used to allocate the buffer for the parsed arguments. Parsing is then done on the WTF-16 slice directly, with the arguments being converted to WTF-8 on-the-fly. This has a few (minor) benefits: - Cuts the amount of memory allocated by ArgIteratorWindows in half (or better) - Makes the total amount of memory allocated by ArgIteratorWindows predictable, since, before, the upfront `wtf16LeToWtf8Alloc` call could end up allocating more-memory-than-necessary temporarily due to its internal use of an ArrayList. Now, the amount of memory allocated is always exactly `calcWtf8Len(cmd_line) + 1`.
1 parent 11534aa
Changed files (2)
lib/std/process.zig
@@ -663,11 +663,11 @@ pub const ArgIteratorWasi = struct {
 /// - https://daviddeley.com/autohotkey/parameters/parameters.htm#WINCRULES
 pub const ArgIteratorWindows = struct {
     allocator: Allocator,
-    /// Owned by the iterator.
-    /// Encoded as WTF-8.
-    cmd_line: []const u8,
+    /// Encoded as WTF-16 LE.
+    cmd_line: [:0]const u16,
     index: usize = 0,
-    /// Owned by the iterator. Long enough to hold the entire `cmd_line` plus a null terminator.
+    /// Owned by the iterator. Long enough to hold contiguous NUL-terminated slices
+    /// of each argument encoded as WTF-8.
     buffer: []u8,
     start: usize = 0,
     end: usize = 0,
@@ -676,13 +676,18 @@ pub const ArgIteratorWindows = struct {
 
     /// `cmd_line_w` *must* be a WTF16-LE-encoded string.
     ///
-    /// The iterator makes a copy of `cmd_line_w` converted WTF-8 and keeps it; it does *not* take
-    /// ownership of `cmd_line_w`.
+    /// The iterator stores and uses `cmd_line_w`, so its memory must be valid for
+    /// at least as long as the returned ArgIteratorWindows.
     pub fn init(allocator: Allocator, cmd_line_w: [*:0]const u16) InitError!ArgIteratorWindows {
-        const cmd_line = try unicode.wtf16LeToWtf8Alloc(allocator, mem.sliceTo(cmd_line_w, 0));
-        errdefer allocator.free(cmd_line);
-
-        const buffer = try allocator.alloc(u8, cmd_line.len + 1);
+        const cmd_line = mem.sliceTo(cmd_line_w, 0);
+        const wtf8_len = unicode.calcWtf8Len(cmd_line);
+
+        // This buffer must be large enough to contain contiguous NUL-terminated slices
+        // of each argument. For arguments past the first one, space for the NUL-terminator
+        // is guaranteed due to the necessary whitespace between arugments. However, we need
+        // one extra byte to guarantee enough room for the NUL terminator if the command line
+        // ends up being exactly 1 argument long with no quotes, etc.
+        const buffer = try allocator.alloc(u8, wtf8_len + 1);
         errdefer allocator.free(buffer);
 
         return .{
@@ -714,11 +719,11 @@ pub const ArgIteratorWindows = struct {
             for (0..count) |_| emitCharacter(self, '\\');
         }
 
-        fn emitCharacter(self: *ArgIteratorWindows, char: u8) void {
-            self.buffer[self.end] = char;
-            self.end += 1;
+        fn emitCharacter(self: *ArgIteratorWindows, code_unit: u16) void {
+            const wtf8_len = std.unicode.wtf8Encode(code_unit, self.buffer[self.end..]) catch unreachable;
+            self.end += wtf8_len;
 
-            // Because we are emitting WTF-8 byte-by-byte, we need to
+            // Because we are emitting WTF-8, we need to
             // check to see if we've emitted two consecutive surrogate
             // codepoints that form a valid surrogate pair in order
             // to ensure that we're always emitting well-formed WTF-8
@@ -732,9 +737,7 @@ pub const ArgIteratorWindows = struct {
             // This is relevant when dealing with a WTF-16 encoded
             // command line like this:
             // "<0xD801>"<0xDC37>
-            // which would get converted to WTF-8 in `cmd_line` as:
-            // "<0xED><0xA0><0x81>"<0xED><0xB0><0xB7>
-            // and then after parsing it'd naively get emitted as:
+            // which would get parsed and converted to WTF-8 as:
             // <0xED><0xA0><0x81><0xED><0xB0><0xB7>
             // but instead, we need to recognize the surrogate pair
             // and emit the codepoint it encodes, which in this
@@ -780,7 +783,7 @@ pub const ArgIteratorWindows = struct {
 
         fn emitBackslashes(_: *ArgIteratorWindows, _: usize) void {}
 
-        fn emitCharacter(_: *ArgIteratorWindows, _: u8) void {}
+        fn emitCharacter(_: *ArgIteratorWindows, _: u16) void {}
 
         fn yieldArg(_: *ArgIteratorWindows) bool {
             return true;
@@ -798,7 +801,10 @@ pub const ArgIteratorWindows = struct {
 
             var inside_quotes = false;
             while (true) : (self.index += 1) {
-                const char = if (self.index != self.cmd_line.len) self.cmd_line[self.index] else 0;
+                const char = if (self.index != self.cmd_line.len)
+                    mem.littleToNative(u16, self.cmd_line[self.index])
+                else
+                    0;
                 switch (char) {
                     0 => {
                         return strategy.yieldArg(self);
@@ -823,7 +829,10 @@ pub const ArgIteratorWindows = struct {
 
         // Skip spaces and tabs. The iterator completes if we reach the end of the string here.
         while (true) : (self.index += 1) {
-            const char = if (self.index != self.cmd_line.len) self.cmd_line[self.index] else 0;
+            const char = if (self.index != self.cmd_line.len)
+                mem.littleToNative(u16, self.cmd_line[self.index])
+            else
+                0;
             switch (char) {
                 0 => return strategy.eof,
                 ' ', '\t' => continue,
@@ -844,7 +853,10 @@ pub const ArgIteratorWindows = struct {
         var backslash_count: usize = 0;
         var inside_quotes = false;
         while (true) : (self.index += 1) {
-            const char = if (self.index != self.cmd_line.len) self.cmd_line[self.index] else 0;
+            const char = if (self.index != self.cmd_line.len)
+                mem.littleToNative(u16, self.cmd_line[self.index])
+            else
+                0;
             switch (char) {
                 0 => {
                     strategy.emitBackslashes(self, backslash_count);
@@ -867,7 +879,7 @@ pub const ArgIteratorWindows = struct {
                     } else {
                         if (inside_quotes and
                             self.index + 1 != self.cmd_line.len and
-                            self.cmd_line[self.index + 1] == '"')
+                            mem.littleToNative(u16, self.cmd_line[self.index + 1]) == '"')
                         {
                             strategy.emitCharacter(self, '"');
                             self.index += 1;
@@ -892,7 +904,6 @@ pub const ArgIteratorWindows = struct {
     /// argument slices.
     pub fn deinit(self: *ArgIteratorWindows) void {
         self.allocator.free(self.buffer);
-        self.allocator.free(self.cmd_line);
     }
 };
 
lib/std/unicode.zig
@@ -2107,3 +2107,36 @@ test "well-formed WTF-16 roundtrips" {
         mem.nativeToLittle(u16, 0xDC00), // low surrogate
     });
 }
+
+/// Returns the length, in bytes, that would be necessary to encode the
+/// given WTF-16 LE slice as WTF-8.
+pub fn calcWtf8Len(wtf16le: []const u16) usize {
+    var it = Wtf16LeIterator.init(wtf16le);
+    var num_wtf8_bytes: usize = 0;
+    while (it.nextCodepoint()) |codepoint| {
+        // Note: If utf8CodepointSequenceLength is ever changed to error on surrogate
+        // codepoints, then it would no longer be eligible to be used in this context.
+        num_wtf8_bytes += utf8CodepointSequenceLength(codepoint) catch |err| switch (err) {
+            error.CodepointTooLarge => unreachable,
+        };
+    }
+    return num_wtf8_bytes;
+}
+
+fn testCalcWtf8Len() !void {
+    const L = utf8ToUtf16LeStringLiteral;
+    try testing.expectEqual(@as(usize, 1), calcWtf8Len(L("a")));
+    try testing.expectEqual(@as(usize, 10), calcWtf8Len(L("abcdefghij")));
+    // unpaired surrogate
+    try testing.expectEqual(@as(usize, 3), calcWtf8Len(&[_]u16{
+        mem.nativeToLittle(u16, 0xD800),
+    }));
+    try testing.expectEqual(@as(usize, 15), calcWtf8Len(L("こんにちは")));
+    // First codepoints that are encoded as 1, 2, 3, and 4 bytes
+    try testing.expectEqual(@as(usize, 1 + 2 + 3 + 4), calcWtf8Len(L("\u{0}\u{80}\u{800}\u{10000}")));
+}
+
+test "calculate wtf8 string length of given wtf16 string" {
+    try testCalcWtf8Len();
+    try comptime testCalcWtf8Len();
+}