Commit 600b652825

Andrew Kelley <andrew@ziglang.org>
2024-04-24 22:49:43
Merge pull request #19698 from squeek502/windows-batbadbut
std.process.Child: Mitigate arbitrary command execution vulnerability on Windows (BatBadBut)
1 parent e36bf2b
Changed files (8)
lib/std/os/windows/kernel32.zig
@@ -243,6 +243,8 @@ pub extern "kernel32" fn GetSystemInfo(lpSystemInfo: *SYSTEM_INFO) callconv(WINA
 pub extern "kernel32" fn GetSystemTimeAsFileTime(*FILETIME) callconv(WINAPI) void;
 pub extern "kernel32" fn IsProcessorFeaturePresent(ProcessorFeature: DWORD) BOOL;
 
+pub extern "kernel32" fn GetSystemDirectoryW(lpBuffer: LPWSTR, uSize: UINT) callconv(WINAPI) UINT;
+
 pub extern "kernel32" fn HeapCreate(flOptions: DWORD, dwInitialSize: SIZE_T, dwMaximumSize: SIZE_T) callconv(WINAPI) ?HANDLE;
 pub extern "kernel32" fn HeapDestroy(hHeap: HANDLE) callconv(WINAPI) BOOL;
 pub extern "kernel32" fn HeapReAlloc(hHeap: HANDLE, dwFlags: DWORD, lpMem: *anyopaque, dwBytes: SIZE_T) callconv(WINAPI) ?*anyopaque;
lib/std/child_process.zig
@@ -136,6 +136,14 @@ pub const ChildProcess = struct {
 
         /// Windows-only. `cwd` was provided, but the path did not exist when spawning the child process.
         CurrentWorkingDirectoryUnlinked,
+
+        /// Windows-only. NUL (U+0000), LF (U+000A), CR (U+000D) are not allowed
+        /// within arguments when executing a `.bat`/`.cmd` script.
+        /// - NUL/LF signifiies end of arguments, so anything afterwards
+        ///   would be lost after execution.
+        /// - CR is stripped by `cmd.exe`, so any CR codepoints
+        ///   would be lost after execution.
+        InvalidBatchScriptArg,
     } ||
         posix.ExecveError ||
         posix.SetIdError ||
@@ -814,17 +822,20 @@ pub const ChildProcess = struct {
         const app_name_w = try unicode.wtf8ToWtf16LeAllocZ(self.allocator, app_basename_wtf8);
         defer self.allocator.free(app_name_w);
 
-        const cmd_line_w = argvToCommandLineWindows(self.allocator, self.argv) catch |err| switch (err) {
-            // argv[0] contains unsupported characters that will never resolve to a valid exe.
-            error.InvalidArg0 => return error.FileNotFound,
-            else => |e| return e,
-        };
-        defer self.allocator.free(cmd_line_w);
-
         run: {
             const PATH: [:0]const u16 = std.process.getenvW(unicode.utf8ToUtf16LeStringLiteral("PATH")) orelse &[_:0]u16{};
             const PATHEXT: [:0]const u16 = std.process.getenvW(unicode.utf8ToUtf16LeStringLiteral("PATHEXT")) orelse &[_:0]u16{};
 
+            // In case the command ends up being a .bat/.cmd script, we need to escape things using the cmd.exe rules
+            // and invoke cmd.exe ourselves in order to mitigate arbitrary command execution from maliciously
+            // constructed arguments.
+            //
+            // We'll need to wait until we're actually trying to run the command to know for sure
+            // if the resolved command has the `.bat` or `.cmd` extension, so we defer actually
+            // serializing the command line until we determine how it should be serialized.
+            var cmd_line_cache = WindowsCommandLineCache.init(self.allocator, self.argv);
+            defer cmd_line_cache.deinit();
+
             var app_buf = std.ArrayListUnmanaged(u16){};
             defer app_buf.deinit(self.allocator);
 
@@ -846,8 +857,10 @@ pub const ChildProcess = struct {
                 dir_buf.shrinkRetainingCapacity(normalized_len);
             }
 
-            windowsCreateProcessPathExt(self.allocator, &dir_buf, &app_buf, PATHEXT, cmd_line_w.ptr, envp_ptr, cwd_w_ptr, &siStartInfo, &piProcInfo) catch |no_path_err| {
+            windowsCreateProcessPathExt(self.allocator, &dir_buf, &app_buf, PATHEXT, &cmd_line_cache, envp_ptr, cwd_w_ptr, &siStartInfo, &piProcInfo) catch |no_path_err| {
                 const original_err = switch (no_path_err) {
+                    // argv[0] contains unsupported characters that will never resolve to a valid exe.
+                    error.InvalidArg0 => return error.FileNotFound,
                     error.FileNotFound, error.InvalidExe, error.AccessDenied => |e| e,
                     error.UnrecoverableInvalidExe => return error.InvalidExe,
                     else => |e| return e,
@@ -872,9 +885,11 @@ pub const ChildProcess = struct {
                     const normalized_len = windows.normalizePath(u16, dir_buf.items) catch continue;
                     dir_buf.shrinkRetainingCapacity(normalized_len);
 
-                    if (windowsCreateProcessPathExt(self.allocator, &dir_buf, &app_buf, PATHEXT, cmd_line_w.ptr, envp_ptr, cwd_w_ptr, &siStartInfo, &piProcInfo)) {
+                    if (windowsCreateProcessPathExt(self.allocator, &dir_buf, &app_buf, PATHEXT, &cmd_line_cache, envp_ptr, cwd_w_ptr, &siStartInfo, &piProcInfo)) {
                         break :run;
                     } else |err| switch (err) {
+                        // argv[0] contains unsupported characters that will never resolve to a valid exe.
+                        error.InvalidArg0 => return error.FileNotFound,
                         error.FileNotFound, error.AccessDenied, error.InvalidExe => continue,
                         error.UnrecoverableInvalidExe => return error.InvalidExe,
                         else => |e| return e,
@@ -935,7 +950,7 @@ fn windowsCreateProcessPathExt(
     dir_buf: *std.ArrayListUnmanaged(u16),
     app_buf: *std.ArrayListUnmanaged(u16),
     pathext: [:0]const u16,
-    cmd_line: [*:0]u16,
+    cmd_line_cache: *WindowsCommandLineCache,
     envp_ptr: ?[*]u16,
     cwd_ptr: ?[*:0]u16,
     lpStartupInfo: *windows.STARTUPINFOW,
@@ -1069,7 +1084,26 @@ fn windowsCreateProcessPathExt(
             try dir_buf.append(allocator, 0);
             const full_app_name = dir_buf.items[0 .. dir_buf.items.len - 1 :0];
 
-            if (windowsCreateProcess(full_app_name.ptr, cmd_line, envp_ptr, cwd_ptr, lpStartupInfo, lpProcessInformation)) |_| {
+            const is_bat_or_cmd = bat_or_cmd: {
+                const app_name = app_buf.items[0..app_name_len];
+                const ext_start = std.mem.lastIndexOfScalar(u16, app_name, '.') orelse break :bat_or_cmd false;
+                const ext = app_name[ext_start..];
+                const ext_enum = windowsCreateProcessSupportsExtension(ext) orelse break :bat_or_cmd false;
+                switch (ext_enum) {
+                    .cmd, .bat => break :bat_or_cmd true,
+                    else => break :bat_or_cmd false,
+                }
+            };
+            const cmd_line_w = if (is_bat_or_cmd)
+                try cmd_line_cache.scriptCommandLine(full_app_name)
+            else
+                try cmd_line_cache.commandLine();
+            const app_name_w = if (is_bat_or_cmd)
+                try cmd_line_cache.cmdExePath()
+            else
+                full_app_name;
+
+            if (windowsCreateProcess(app_name_w.ptr, cmd_line_w.ptr, envp_ptr, cwd_ptr, lpStartupInfo, lpProcessInformation)) |_| {
                 return;
             } else |err| switch (err) {
                 error.FileNotFound,
@@ -1111,7 +1145,20 @@ fn windowsCreateProcessPathExt(
         try dir_buf.append(allocator, 0);
         const full_app_name = dir_buf.items[0 .. dir_buf.items.len - 1 :0];
 
-        if (windowsCreateProcess(full_app_name.ptr, cmd_line, envp_ptr, cwd_ptr, lpStartupInfo, lpProcessInformation)) |_| {
+        const is_bat_or_cmd = switch (ext_enum) {
+            .cmd, .bat => true,
+            else => false,
+        };
+        const cmd_line_w = if (is_bat_or_cmd)
+            try cmd_line_cache.scriptCommandLine(full_app_name)
+        else
+            try cmd_line_cache.commandLine();
+        const app_name_w = if (is_bat_or_cmd)
+            try cmd_line_cache.cmdExePath()
+        else
+            full_app_name;
+
+        if (windowsCreateProcess(app_name_w.ptr, cmd_line_w.ptr, envp_ptr, cwd_ptr, lpStartupInfo, lpProcessInformation)) |_| {
             return;
         } else |err| switch (err) {
             error.FileNotFound => continue,
@@ -1236,6 +1283,223 @@ test windowsCreateProcessSupportsExtension {
     try std.testing.expect(windowsCreateProcessSupportsExtension(&[_]u16{ '.', 'e', 'X', 'e', 'c' }) == null);
 }
 
+/// Serializes argv into a WTF-16 encoded command-line string for use with CreateProcessW.
+///
+/// Serialization is done on-demand and the result is cached in order to allow for:
+/// - Only serializing the particular type of command line needed (`.bat`/`.cmd`
+///   command line serialization is different from `.exe`/etc)
+/// - Reusing the serialized command lines if necessary (i.e. if the execution
+///   of a command fails and the PATH is going to be continued to be searched
+///   for more candidates)
+pub const WindowsCommandLineCache = struct {
+    cmd_line: ?[:0]u16 = null,
+    script_cmd_line: ?[:0]u16 = null,
+    cmd_exe_path: ?[:0]u16 = null,
+    argv: []const []const u8,
+    allocator: mem.Allocator,
+
+    pub fn init(allocator: mem.Allocator, argv: []const []const u8) WindowsCommandLineCache {
+        return .{
+            .allocator = allocator,
+            .argv = argv,
+        };
+    }
+
+    pub fn deinit(self: *WindowsCommandLineCache) void {
+        if (self.cmd_line) |cmd_line| self.allocator.free(cmd_line);
+        if (self.script_cmd_line) |script_cmd_line| self.allocator.free(script_cmd_line);
+        if (self.cmd_exe_path) |cmd_exe_path| self.allocator.free(cmd_exe_path);
+    }
+
+    pub fn commandLine(self: *WindowsCommandLineCache) ![:0]u16 {
+        if (self.cmd_line == null) {
+            self.cmd_line = try argvToCommandLineWindows(self.allocator, self.argv);
+        }
+        return self.cmd_line.?;
+    }
+
+    /// Not cached, since the path to the batch script will change during PATH searching.
+    /// `script_path` should be as qualified as possible, e.g. if the PATH is being searched,
+    /// then script_path should include both the search path and the script filename
+    /// (this allows avoiding cmd.exe having to search the PATH again).
+    pub fn scriptCommandLine(self: *WindowsCommandLineCache, script_path: []const u16) ![:0]u16 {
+        if (self.script_cmd_line) |v| self.allocator.free(v);
+        self.script_cmd_line = try argvToScriptCommandLineWindows(
+            self.allocator,
+            script_path,
+            self.argv[1..],
+        );
+        return self.script_cmd_line.?;
+    }
+
+    pub fn cmdExePath(self: *WindowsCommandLineCache) ![:0]u16 {
+        if (self.cmd_exe_path == null) {
+            self.cmd_exe_path = try windowsCmdExePath(self.allocator);
+        }
+        return self.cmd_exe_path.?;
+    }
+};
+
+pub fn windowsCmdExePath(allocator: mem.Allocator) error{ OutOfMemory, Unexpected }![:0]u16 {
+    var buf = try std.ArrayListUnmanaged(u16).initCapacity(allocator, 128);
+    errdefer buf.deinit(allocator);
+    while (true) {
+        const unused_slice = buf.unusedCapacitySlice();
+        // TODO: Get the system directory from PEB.ReadOnlyStaticServerData
+        const len = windows.kernel32.GetSystemDirectoryW(@ptrCast(unused_slice), @intCast(unused_slice.len));
+        if (len == 0) {
+            switch (windows.kernel32.GetLastError()) {
+                else => |err| return windows.unexpectedError(err),
+            }
+        }
+        if (len > unused_slice.len) {
+            try buf.ensureUnusedCapacity(allocator, len);
+        } else {
+            buf.items.len = len;
+            break;
+        }
+    }
+    switch (buf.items[buf.items.len - 1]) {
+        '/', '\\' => {},
+        else => try buf.append(allocator, fs.path.sep),
+    }
+    try buf.appendSlice(allocator, std.unicode.utf8ToUtf16LeStringLiteral("cmd.exe"));
+    return try buf.toOwnedSliceSentinel(allocator, 0);
+}
+
+pub const ArgvToScriptCommandLineError = error{
+    OutOfMemory,
+    InvalidWtf8,
+    /// NUL (U+0000), LF (U+000A), CR (U+000D) are not allowed
+    /// within arguments when executing a `.bat`/`.cmd` script.
+    /// - NUL/LF signifiies end of arguments, so anything afterwards
+    ///   would be lost after execution.
+    /// - CR is stripped by `cmd.exe`, so any CR codepoints
+    ///   would be lost after execution.
+    InvalidBatchScriptArg,
+};
+
+/// Serializes `argv` to a Windows command-line string that uses `cmd.exe /c` and `cmd.exe`-specific
+/// escaping rules. The caller owns the returned slice.
+///
+/// Escapes `argv` using the suggested mitigation against arbitrary command execution from:
+/// https://flatt.tech/research/posts/batbadbut-you-cant-securely-execute-commands-on-windows/
+pub fn argvToScriptCommandLineWindows(
+    allocator: mem.Allocator,
+    /// Path to the `.bat`/`.cmd` script. If this path is relative, it is assumed to be relative to the CWD.
+    /// The script must have been verified to exist at this path before calling this function.
+    script_path: []const u16,
+    /// Arguments, not including the script name itself. Expected to be encoded as WTF-8.
+    script_args: []const []const u8,
+) ArgvToScriptCommandLineError![:0]u16 {
+    var buf = try std.ArrayList(u8).initCapacity(allocator, 64);
+    defer buf.deinit();
+
+    // `/d` disables execution of AutoRun commands.
+    // `/e:ON` and `/v:OFF` are needed for BatBadBut mitigation:
+    // > If delayed expansion is enabled via the registry value DelayedExpansion,
+    // > it must be disabled by explicitly calling cmd.exe with the /V:OFF option.
+    // > Escaping for % requires the command extension to be enabled.
+    // > If it’s disabled via the registry value EnableExtensions, it must be enabled with the /E:ON option.
+    // https://flatt.tech/research/posts/batbadbut-you-cant-securely-execute-commands-on-windows/
+    buf.appendSliceAssumeCapacity("cmd.exe /d /e:ON /v:OFF /c \"");
+
+    // Always quote the path to the script arg
+    buf.appendAssumeCapacity('"');
+    // We always want the path to the batch script to include a path separator in order to
+    // avoid cmd.exe searching the PATH for the script. This is not part of the arbitrary
+    // command execution mitigation, we just know exactly what script we want to execute
+    // at this point, and potentially making cmd.exe re-find it is unnecessary.
+    //
+    // If the script path does not have a path separator, then we know its relative to CWD and
+    // we can just put `.\` in the front.
+    if (mem.indexOfAny(u16, script_path, &[_]u16{ mem.nativeToLittle(u16, '\\'), mem.nativeToLittle(u16, '/') }) == null) {
+        try buf.appendSlice(".\\");
+    }
+    // Note that we don't do any escaping/mitigations for this argument, since the relevant
+    // characters (", %, etc) are illegal in file paths and this function should only be called
+    // with script paths that have been verified to exist.
+    try std.unicode.wtf16LeToWtf8ArrayList(&buf, script_path);
+    buf.appendAssumeCapacity('"');
+
+    for (script_args) |arg| {
+        // Literal carriage returns get stripped when run through cmd.exe
+        // and NUL/newlines act as 'end of command.' Because of this, it's basically
+        // always a mistake to include these characters in argv, so it's
+        // an error condition in order to ensure that the return of this
+        // function can always roundtrip through cmd.exe.
+        if (std.mem.indexOfAny(u8, arg, "\x00\r\n") != null) {
+            return error.InvalidBatchScriptArg;
+        }
+
+        // Separate args with a space.
+        try buf.append(' ');
+
+        // Need to quote if the argument is empty (otherwise the arg would just be lost)
+        // or if the last character is a `\`, since then something like "%~2" in a .bat
+        // script would cause the closing " to be escaped which we don't want.
+        var needs_quotes = arg.len == 0 or arg[arg.len - 1] == '\\';
+        if (!needs_quotes) {
+            for (arg) |c| {
+                switch (c) {
+                    // Known good characters that don't need to be quoted
+                    'A'...'Z', 'a'...'z', '0'...'9', '#', '$', '*', '+', '-', '.', '/', ':', '?', '@', '\\', '_' => {},
+                    // When in doubt, quote
+                    else => {
+                        needs_quotes = true;
+                        break;
+                    },
+                }
+            }
+        }
+        if (needs_quotes) {
+            try buf.append('"');
+        }
+        var backslashes: usize = 0;
+        for (arg) |c| {
+            switch (c) {
+                '\\' => {
+                    backslashes += 1;
+                },
+                '"' => {
+                    try buf.appendNTimes('\\', backslashes);
+                    try buf.append('"');
+                    backslashes = 0;
+                },
+                // Replace `%` with `%%cd:~,%`.
+                //
+                // cmd.exe allows extracting a substring from an environment
+                // variable with the syntax: `%foo:~<start_index>,<end_index>%`.
+                // Therefore, `%cd:~,%` will always expand to an empty string
+                // since both the start and end index are blank, and it is assumed
+                // that `%cd%` is always available since it is a built-in variable
+                // that corresponds to the current directory.
+                //
+                // This means that replacing `%foo%` with `%%cd:~,%foo%%cd:~,%`
+                // will stop `%foo%` from being expanded and *after* expansion
+                // we'll still be left with `%foo%` (the literal string).
+                '%' => {
+                    // the trailing `%` is appended outside the switch
+                    try buf.appendSlice("%%cd:~,");
+                    backslashes = 0;
+                },
+                else => {
+                    backslashes = 0;
+                },
+            }
+            try buf.append(c);
+        }
+        if (needs_quotes) {
+            try buf.appendNTimes('\\', backslashes);
+            try buf.append('"');
+        }
+    }
+
+    try buf.append('"');
+
+    return try unicode.wtf8ToWtf16LeAllocZ(allocator, buf.items);
+}
+
 pub const ArgvToCommandLineError = error{ OutOfMemory, InvalidWtf8, InvalidArg0 };
 
 /// Serializes `argv` to a Windows command-line string suitable for passing to a child process and
lib/std/unicode.zig
@@ -934,7 +934,7 @@ fn utf16LeToUtf8ArrayListImpl(
     .cannot_encode_surrogate_half => Utf16LeToUtf8AllocError,
     .can_encode_surrogate_half => mem.Allocator.Error,
 })!void {
-    assert(result.capacity >= utf16le.len);
+    assert(result.unusedCapacitySlice().len >= utf16le.len);
 
     var remaining = utf16le;
     vectorized: {
@@ -979,7 +979,7 @@ fn utf16LeToUtf8ArrayListImpl(
 pub const Utf16LeToUtf8AllocError = mem.Allocator.Error || Utf16LeToUtf8Error;
 
 pub fn utf16LeToUtf8ArrayList(result: *std.ArrayList(u8), utf16le: []const u16) Utf16LeToUtf8AllocError!void {
-    try result.ensureTotalCapacityPrecise(utf16le.len);
+    try result.ensureUnusedCapacity(utf16le.len);
     return utf16LeToUtf8ArrayListImpl(result, utf16le, .cannot_encode_surrogate_half);
 }
 
@@ -1138,7 +1138,7 @@ test utf16LeToUtf8 {
 }
 
 fn utf8ToUtf16LeArrayListImpl(result: *std.ArrayList(u16), utf8: []const u8, comptime surrogates: Surrogates) !void {
-    assert(result.capacity >= utf8.len);
+    assert(result.unusedCapacitySlice().len >= utf8.len);
 
     var remaining = utf8;
     vectorized: {
@@ -1176,7 +1176,7 @@ fn utf8ToUtf16LeArrayListImpl(result: *std.ArrayList(u16), utf8: []const u8, com
 }
 
 pub fn utf8ToUtf16LeArrayList(result: *std.ArrayList(u16), utf8: []const u8) error{ InvalidUtf8, OutOfMemory }!void {
-    try result.ensureTotalCapacityPrecise(utf8.len);
+    try result.ensureUnusedCapacity(utf8.len);
     return utf8ToUtf16LeArrayListImpl(result, utf8, .cannot_encode_surrogate_half);
 }
 
@@ -1351,6 +1351,64 @@ test utf8ToUtf16LeAllocZ {
     }
 }
 
+test "ArrayList functions on a re-used list" {
+    // utf8ToUtf16LeArrayList
+    {
+        var list = std.ArrayList(u16).init(testing.allocator);
+        defer list.deinit();
+
+        const init_slice = utf8ToUtf16LeStringLiteral("abcdefg");
+        try list.ensureTotalCapacityPrecise(init_slice.len);
+        list.appendSliceAssumeCapacity(init_slice);
+
+        try utf8ToUtf16LeArrayList(&list, "hijklmnopqrstuvwyxz");
+
+        try testing.expectEqualSlices(u16, utf8ToUtf16LeStringLiteral("abcdefghijklmnopqrstuvwyxz"), list.items);
+    }
+
+    // utf16LeToUtf8ArrayList
+    {
+        var list = std.ArrayList(u8).init(testing.allocator);
+        defer list.deinit();
+
+        const init_slice = "abcdefg";
+        try list.ensureTotalCapacityPrecise(init_slice.len);
+        list.appendSliceAssumeCapacity(init_slice);
+
+        try utf16LeToUtf8ArrayList(&list, utf8ToUtf16LeStringLiteral("hijklmnopqrstuvwyxz"));
+
+        try testing.expectEqualStrings("abcdefghijklmnopqrstuvwyxz", list.items);
+    }
+
+    // wtf8ToWtf16LeArrayList
+    {
+        var list = std.ArrayList(u16).init(testing.allocator);
+        defer list.deinit();
+
+        const init_slice = utf8ToUtf16LeStringLiteral("abcdefg");
+        try list.ensureTotalCapacityPrecise(init_slice.len);
+        list.appendSliceAssumeCapacity(init_slice);
+
+        try wtf8ToWtf16LeArrayList(&list, "hijklmnopqrstuvwyxz");
+
+        try testing.expectEqualSlices(u16, utf8ToUtf16LeStringLiteral("abcdefghijklmnopqrstuvwyxz"), list.items);
+    }
+
+    // wtf16LeToWtf8ArrayList
+    {
+        var list = std.ArrayList(u8).init(testing.allocator);
+        defer list.deinit();
+
+        const init_slice = "abcdefg";
+        try list.ensureTotalCapacityPrecise(init_slice.len);
+        list.appendSliceAssumeCapacity(init_slice);
+
+        try wtf16LeToWtf8ArrayList(&list, utf8ToUtf16LeStringLiteral("hijklmnopqrstuvwyxz"));
+
+        try testing.expectEqualStrings("abcdefghijklmnopqrstuvwyxz", list.items);
+    }
+}
+
 /// Converts a UTF-8 string literal into a UTF-16LE string literal.
 pub fn utf8ToUtf16LeStringLiteral(comptime utf8: []const u8) *const [calcUtf16LeLen(utf8) catch |err| @compileError(err):0]u16 {
     return comptime blk: {
@@ -1685,7 +1743,7 @@ pub const Wtf8Iterator = struct {
 };
 
 pub fn wtf16LeToWtf8ArrayList(result: *std.ArrayList(u8), utf16le: []const u16) mem.Allocator.Error!void {
-    try result.ensureTotalCapacityPrecise(utf16le.len);
+    try result.ensureUnusedCapacity(utf16le.len);
     return utf16LeToUtf8ArrayListImpl(result, utf16le, .can_encode_surrogate_half);
 }
 
@@ -1714,7 +1772,7 @@ pub fn wtf16LeToWtf8(wtf8: []u8, wtf16le: []const u16) usize {
 }
 
 pub fn wtf8ToWtf16LeArrayList(result: *std.ArrayList(u16), wtf8: []const u8) error{ InvalidWtf8, OutOfMemory }!void {
-    try result.ensureTotalCapacityPrecise(wtf8.len);
+    try result.ensureUnusedCapacity(wtf8.len);
     return utf8ToUtf16LeArrayListImpl(result, wtf8, .can_encode_surrogate_half);
 }
 
test/standalone/windows_bat_args/build.zig
@@ -0,0 +1,58 @@
+const std = @import("std");
+const builtin = @import("builtin");
+
+pub fn build(b: *std.Build) !void {
+    const test_step = b.step("test", "Test it");
+    b.default_step = test_step;
+
+    const optimize: std.builtin.OptimizeMode = .Debug;
+    const target = b.host;
+
+    if (builtin.os.tag != .windows) return;
+
+    const echo_args = b.addExecutable(.{
+        .name = "echo-args",
+        .root_source_file = b.path("echo-args.zig"),
+        .optimize = optimize,
+        .target = target,
+    });
+
+    const test_exe = b.addExecutable(.{
+        .name = "test",
+        .root_source_file = b.path("test.zig"),
+        .optimize = optimize,
+        .target = target,
+    });
+
+    const run = b.addRunArtifact(test_exe);
+    run.addArtifactArg(echo_args);
+    run.expectExitCode(0);
+    run.skip_foreign_checks = true;
+
+    test_step.dependOn(&run.step);
+
+    const fuzz = b.addExecutable(.{
+        .name = "fuzz",
+        .root_source_file = b.path("fuzz.zig"),
+        .optimize = optimize,
+        .target = target,
+    });
+
+    const fuzz_max_iterations = b.option(u64, "iterations", "The max fuzz iterations (default: 100)") orelse 100;
+    const fuzz_iterations_arg = std.fmt.allocPrint(b.allocator, "{}", .{fuzz_max_iterations}) catch @panic("oom");
+
+    const fuzz_seed = b.option(u64, "seed", "Seed to use for the PRNG (default: random)") orelse seed: {
+        var buf: [8]u8 = undefined;
+        try std.posix.getrandom(&buf);
+        break :seed std.mem.readInt(u64, &buf, builtin.cpu.arch.endian());
+    };
+    const fuzz_seed_arg = std.fmt.allocPrint(b.allocator, "{}", .{fuzz_seed}) catch @panic("oom");
+
+    const fuzz_run = b.addRunArtifact(fuzz);
+    fuzz_run.addArtifactArg(echo_args);
+    fuzz_run.addArgs(&.{ fuzz_iterations_arg, fuzz_seed_arg });
+    fuzz_run.expectExitCode(0);
+    fuzz_run.skip_foreign_checks = true;
+
+    test_step.dependOn(&fuzz_run.step);
+}
test/standalone/windows_bat_args/echo-args.zig
@@ -0,0 +1,14 @@
+const std = @import("std");
+
+pub fn main() !void {
+    var arena_state = std.heap.ArenaAllocator.init(std.heap.page_allocator);
+    defer arena_state.deinit();
+    const arena = arena_state.allocator();
+
+    const stdout = std.io.getStdOut().writer();
+    var args = try std.process.argsAlloc(arena);
+    for (args[1..], 1..) |arg, i| {
+        try stdout.writeAll(arg);
+        if (i != args.len - 1) try stdout.writeByte('\x00');
+    }
+}
test/standalone/windows_bat_args/fuzz.zig
@@ -0,0 +1,160 @@
+const std = @import("std");
+const builtin = @import("builtin");
+const Allocator = std.mem.Allocator;
+
+pub fn main() anyerror!void {
+    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
+    defer if (gpa.deinit() == .leak) @panic("found memory leaks");
+    const allocator = gpa.allocator();
+
+    var it = try std.process.argsWithAllocator(allocator);
+    defer it.deinit();
+    _ = it.next() orelse unreachable; // skip binary name
+    const child_exe_path = it.next() orelse unreachable;
+
+    const iterations: u64 = iterations: {
+        const arg = it.next() orelse "0";
+        break :iterations try std.fmt.parseUnsigned(u64, arg, 10);
+    };
+
+    var rand_seed = false;
+    const seed: u64 = seed: {
+        const seed_arg = it.next() orelse {
+            rand_seed = true;
+            var buf: [8]u8 = undefined;
+            try std.posix.getrandom(&buf);
+            break :seed std.mem.readInt(u64, &buf, builtin.cpu.arch.endian());
+        };
+        break :seed try std.fmt.parseUnsigned(u64, seed_arg, 10);
+    };
+    var random = std.rand.DefaultPrng.init(seed);
+    const rand = random.random();
+
+    // If the seed was not given via the CLI, then output the
+    // randomly chosen seed so that this run can be reproduced
+    if (rand_seed) {
+        std.debug.print("rand seed: {}\n", .{seed});
+    }
+
+    var tmp = std.testing.tmpDir(.{});
+    defer tmp.cleanup();
+
+    try tmp.dir.setAsCwd();
+    defer tmp.parent_dir.setAsCwd() catch {};
+
+    var buf = try std.ArrayList(u8).initCapacity(allocator, 128);
+    defer buf.deinit();
+    try buf.appendSlice("@echo off\n");
+    try buf.append('"');
+    try buf.appendSlice(child_exe_path);
+    try buf.append('"');
+    const preamble_len = buf.items.len;
+
+    try buf.appendSlice(" %*");
+    try tmp.dir.writeFile("args1.bat", buf.items);
+    buf.shrinkRetainingCapacity(preamble_len);
+
+    try buf.appendSlice(" %1 %2 %3 %4 %5 %6 %7 %8 %9");
+    try tmp.dir.writeFile("args2.bat", buf.items);
+    buf.shrinkRetainingCapacity(preamble_len);
+
+    try buf.appendSlice(" \"%~1\" \"%~2\" \"%~3\" \"%~4\" \"%~5\" \"%~6\" \"%~7\" \"%~8\" \"%~9\"");
+    try tmp.dir.writeFile("args3.bat", buf.items);
+    buf.shrinkRetainingCapacity(preamble_len);
+
+    var i: u64 = 0;
+    while (iterations == 0 or i < iterations) {
+        const rand_arg = try randomArg(allocator, rand);
+        defer allocator.free(rand_arg);
+
+        try testExec(allocator, &.{rand_arg}, null);
+
+        i += 1;
+    }
+}
+
+fn testExec(allocator: std.mem.Allocator, args: []const []const u8, env: ?*std.process.EnvMap) !void {
+    try testExecBat(allocator, "args1.bat", args, env);
+    try testExecBat(allocator, "args2.bat", args, env);
+    try testExecBat(allocator, "args3.bat", args, env);
+}
+
+fn testExecBat(allocator: std.mem.Allocator, bat: []const u8, args: []const []const u8, env: ?*std.process.EnvMap) !void {
+    var argv = try std.ArrayList([]const u8).initCapacity(allocator, 1 + args.len);
+    defer argv.deinit();
+    argv.appendAssumeCapacity(bat);
+    argv.appendSliceAssumeCapacity(args);
+
+    const can_have_trailing_empty_args = std.mem.eql(u8, bat, "args3.bat");
+
+    const result = try std.ChildProcess.run(.{
+        .allocator = allocator,
+        .env_map = env,
+        .argv = argv.items,
+    });
+    defer allocator.free(result.stdout);
+    defer allocator.free(result.stderr);
+
+    try std.testing.expectEqualStrings("", result.stderr);
+    var it = std.mem.splitScalar(u8, result.stdout, '\x00');
+    var i: usize = 0;
+    while (it.next()) |actual_arg| {
+        if (i >= args.len and can_have_trailing_empty_args) {
+            try std.testing.expectEqualStrings("", actual_arg);
+            continue;
+        }
+        const expected_arg = args[i];
+        try std.testing.expectEqualSlices(u8, expected_arg, actual_arg);
+        i += 1;
+    }
+}
+
+fn randomArg(allocator: Allocator, rand: std.rand.Random) ![]const u8 {
+    const Choice = enum {
+        backslash,
+        quote,
+        space,
+        control,
+        printable,
+        surrogate_half,
+        non_ascii,
+    };
+
+    const choices = rand.uintAtMostBiased(u16, 256);
+    var buf = try std.ArrayList(u8).initCapacity(allocator, choices);
+    errdefer buf.deinit();
+
+    var last_codepoint: u21 = 0;
+    for (0..choices) |_| {
+        const choice = rand.enumValue(Choice);
+        const codepoint: u21 = switch (choice) {
+            .backslash => '\\',
+            .quote => '"',
+            .space => ' ',
+            .control => switch (rand.uintAtMostBiased(u8, 0x21)) {
+                // NUL/CR/LF can't roundtrip
+                '\x00', '\r', '\n' => ' ',
+                0x21 => '\x7F',
+                else => |b| b,
+            },
+            .printable => '!' + rand.uintAtMostBiased(u8, '~' - '!'),
+            .surrogate_half => rand.intRangeAtMostBiased(u16, 0xD800, 0xDFFF),
+            .non_ascii => rand.intRangeAtMostBiased(u21, 0x80, 0x10FFFF),
+        };
+        // Ensure that we always return well-formed WTF-8.
+        // Instead of concatenating to ensure well-formed WTF-8,
+        // we just skip encoding the low surrogate.
+        if (std.unicode.isSurrogateCodepoint(last_codepoint) and std.unicode.isSurrogateCodepoint(codepoint)) {
+            if (std.unicode.utf16IsHighSurrogate(@intCast(last_codepoint)) and std.unicode.utf16IsLowSurrogate(@intCast(codepoint))) {
+                continue;
+            }
+        }
+        try buf.ensureUnusedCapacity(4);
+        const unused_slice = buf.unusedCapacitySlice();
+        const len = std.unicode.wtf8Encode(codepoint, unused_slice) catch unreachable;
+        buf.items.len += len;
+        last_codepoint = codepoint;
+    }
+
+    return buf.toOwnedSlice();
+}
test/standalone/windows_bat_args/test.zig
@@ -0,0 +1,132 @@
+const std = @import("std");
+
+pub fn main() anyerror!void {
+    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
+    defer if (gpa.deinit() == .leak) @panic("found memory leaks");
+    const allocator = gpa.allocator();
+
+    var it = try std.process.argsWithAllocator(allocator);
+    defer it.deinit();
+    _ = it.next() orelse unreachable; // skip binary name
+    const child_exe_path = it.next() orelse unreachable;
+
+    var tmp = std.testing.tmpDir(.{});
+    defer tmp.cleanup();
+
+    try tmp.dir.setAsCwd();
+    defer tmp.parent_dir.setAsCwd() catch {};
+
+    var buf = try std.ArrayList(u8).initCapacity(allocator, 128);
+    defer buf.deinit();
+    try buf.appendSlice("@echo off\n");
+    try buf.append('"');
+    try buf.appendSlice(child_exe_path);
+    try buf.append('"');
+    const preamble_len = buf.items.len;
+
+    try buf.appendSlice(" %*");
+    try tmp.dir.writeFile("args1.bat", buf.items);
+    buf.shrinkRetainingCapacity(preamble_len);
+
+    try buf.appendSlice(" %1 %2 %3 %4 %5 %6 %7 %8 %9");
+    try tmp.dir.writeFile("args2.bat", buf.items);
+    buf.shrinkRetainingCapacity(preamble_len);
+
+    try buf.appendSlice(" \"%~1\" \"%~2\" \"%~3\" \"%~4\" \"%~5\" \"%~6\" \"%~7\" \"%~8\" \"%~9\"");
+    try tmp.dir.writeFile("args3.bat", buf.items);
+    buf.shrinkRetainingCapacity(preamble_len);
+
+    // Test cases are from https://github.com/rust-lang/rust/blob/master/tests/ui/std/windows-bat-args.rs
+    try testExecError(error.InvalidBatchScriptArg, allocator, &.{"\x00"});
+    try testExecError(error.InvalidBatchScriptArg, allocator, &.{"\n"});
+    try testExecError(error.InvalidBatchScriptArg, allocator, &.{"\r"});
+    try testExec(allocator, &.{ "a", "b" }, null);
+    try testExec(allocator, &.{ "c is for cat", "d is for dog" }, null);
+    try testExec(allocator, &.{ "\"", " \"" }, null);
+    try testExec(allocator, &.{ "\\", "\\" }, null);
+    try testExec(allocator, &.{">file.txt"}, null);
+    try testExec(allocator, &.{"whoami.exe"}, null);
+    try testExec(allocator, &.{"&a.exe"}, null);
+    try testExec(allocator, &.{"&echo hello "}, null);
+    try testExec(allocator, &.{ "&echo hello", "&whoami", ">file.txt" }, null);
+    try testExec(allocator, &.{"!TMP!"}, null);
+    try testExec(allocator, &.{"key=value"}, null);
+    try testExec(allocator, &.{"\"key=value\""}, null);
+    try testExec(allocator, &.{"key = value"}, null);
+    try testExec(allocator, &.{"key=[\"value\"]"}, null);
+    try testExec(allocator, &.{ "", "a=b" }, null);
+    try testExec(allocator, &.{"key=\"foo bar\""}, null);
+    try testExec(allocator, &.{"key=[\"my_value]"}, null);
+    try testExec(allocator, &.{"key=[\"my_value\",\"other-value\"]"}, null);
+    try testExec(allocator, &.{"key\\=value"}, null);
+    try testExec(allocator, &.{"key=\"&whoami\""}, null);
+    try testExec(allocator, &.{"key=\"value\"=5"}, null);
+    try testExec(allocator, &.{"key=[\">file.txt\"]"}, null);
+    try testExec(allocator, &.{"%hello"}, null);
+    try testExec(allocator, &.{"%PATH%"}, null);
+    try testExec(allocator, &.{"%%cd:~,%"}, null);
+    try testExec(allocator, &.{"%PATH%PATH%"}, null);
+    try testExec(allocator, &.{"\">file.txt"}, null);
+    try testExec(allocator, &.{"abc\"&echo hello"}, null);
+    try testExec(allocator, &.{"123\">file.txt"}, null);
+    try testExec(allocator, &.{"\"&echo hello&whoami.exe"}, null);
+    try testExec(allocator, &.{ "\"hello^\"world\"", "hello &echo oh no >file.txt" }, null);
+    try testExec(allocator, &.{"&whoami.exe"}, null);
+
+    var env = env: {
+        var env = try std.process.getEnvMap(allocator);
+        errdefer env.deinit();
+        // No escaping
+        try env.put("FOO", "123");
+        // Some possible escaping of %FOO% that could be expanded
+        // when escaping cmd.exe meta characters with ^
+        try env.put("FOO^", "123"); // only escaping %
+        try env.put("^F^O^O^", "123"); // escaping every char
+        break :env env;
+    };
+    defer env.deinit();
+    try testExec(allocator, &.{"%FOO%"}, &env);
+
+    // Ensure that none of the `>file.txt`s have caused file.txt to be created
+    try std.testing.expectError(error.FileNotFound, tmp.dir.access("file.txt", .{}));
+}
+
+fn testExecError(err: anyerror, allocator: std.mem.Allocator, args: []const []const u8) !void {
+    return std.testing.expectError(err, testExec(allocator, args, null));
+}
+
+fn testExec(allocator: std.mem.Allocator, args: []const []const u8, env: ?*std.process.EnvMap) !void {
+    try testExecBat(allocator, "args1.bat", args, env);
+    try testExecBat(allocator, "args2.bat", args, env);
+    try testExecBat(allocator, "args3.bat", args, env);
+}
+
+fn testExecBat(allocator: std.mem.Allocator, bat: []const u8, args: []const []const u8, env: ?*std.process.EnvMap) !void {
+    var argv = try std.ArrayList([]const u8).initCapacity(allocator, 1 + args.len);
+    defer argv.deinit();
+    argv.appendAssumeCapacity(bat);
+    argv.appendSliceAssumeCapacity(args);
+
+    const can_have_trailing_empty_args = std.mem.eql(u8, bat, "args3.bat");
+
+    const result = try std.ChildProcess.run(.{
+        .allocator = allocator,
+        .env_map = env,
+        .argv = argv.items,
+    });
+    defer allocator.free(result.stdout);
+    defer allocator.free(result.stderr);
+
+    try std.testing.expectEqualStrings("", result.stderr);
+    var it = std.mem.splitScalar(u8, result.stdout, '\x00');
+    var i: usize = 0;
+    while (it.next()) |actual_arg| {
+        if (i >= args.len and can_have_trailing_empty_args) {
+            try std.testing.expectEqualStrings("", actual_arg);
+            continue;
+        }
+        const expected_arg = args[i];
+        try std.testing.expectEqualStrings(expected_arg, actual_arg);
+        i += 1;
+    }
+}
test/standalone/build.zig.zon
@@ -107,6 +107,9 @@
         .windows_argv = .{
             .path = "windows_argv",
         },
+        .windows_bat_args = .{
+            .path = "windows_bat_args",
+        },
         .self_exe_symlink = .{
             .path = "self_exe_symlink",
         },