Commit 3a3d2187f9

Andrew Kelley <andrew@ziglang.org>
2024-05-28 21:31:10
std.Progress: better Windows support
* Merge a bunch of related state together into TerminalMode. Windows sometimes follows the same path as posix via ansi_escape_codes, sometimes not. * Use a different thread entry point for Windows API but share the same entry point on Windows when the terminal is in ansi_escape_codes mode. * Only clear the terminal when the stderr lock is held. * Don't try to clear the terminal when nothing has been written yet. * Don't try to clear the terminal in IPC mode. * Fix size detection logic bug under error conditions.
1 parent 40afac4
Changed files (1)
lib
lib/std/Progress.zig
@@ -8,19 +8,13 @@ const assert = std.debug.assert;
 const Progress = @This();
 const posix = std.posix;
 const is_big_endian = builtin.cpu.arch.endian() == .big;
+const is_windows = builtin.os.tag == .windows;
 
 /// `null` if the current node (and its children) should
 /// not print on update()
-terminal: ?std.fs.File,
+terminal: std.fs.File,
 
-/// Is this a windows API terminal (note: this is not the same as being run on windows
-/// because other terminals exist like MSYS/git-bash)
-is_windows_terminal: bool,
-/// The output code page of the console (only set if the console is a Windows API terminal)
-console_code_page: if (builtin.os.tag == .windows) windows.UINT else void,
-
-/// Whether the terminal supports ANSI escape codes.
-supports_ansi_escape_codes: bool,
+terminal_mode: TerminalMode,
 
 update_thread: ?std.Thread,
 
@@ -53,6 +47,19 @@ node_freelist: []Node.OptionalIndex,
 node_freelist_first: Node.OptionalIndex,
 node_end_index: u32,
 
+pub const TerminalMode = union(enum) {
+    off,
+    ansi_escape_codes,
+    /// This is not the same as being run on windows because other terminals
+    /// exist like MSYS/git-bash.
+    windows_api: if (is_windows) WindowsApi else void,
+
+    pub const WindowsApi = struct {
+        /// The output code page of the console.
+        code_page: windows.UINT,
+    };
+};
+
 pub const Options = struct {
     /// User-provided buffer with static lifetime.
     ///
@@ -297,10 +304,8 @@ pub const Node = struct {
 };
 
 var global_progress: Progress = .{
-    .terminal = null,
-    .is_windows_terminal = false,
-    .console_code_page = if (builtin.os.tag == .windows) undefined else {},
-    .supports_ansi_escape_codes = false,
+    .terminal = undefined,
+    .terminal_mode = .off,
     .update_thread = null,
     .redraw_event = .{},
     .refresh_rate_ns = undefined,
@@ -376,20 +381,16 @@ pub fn start(options: Options) Node {
                 return .{ .index = .none };
             }
             const stderr = std.io.getStdErr();
+            global_progress.terminal = stderr;
             if (stderr.supportsAnsiEscapeCodes()) {
-                global_progress.terminal = stderr;
-                global_progress.supports_ansi_escape_codes = true;
-            } else if (builtin.os.tag == .windows and stderr.isTty()) {
-                global_progress.is_windows_terminal = true;
-                global_progress.console_code_page = windows.kernel32.GetConsoleOutputCP();
-                global_progress.terminal = stderr;
-            } else if (builtin.os.tag != .windows) {
-                // we are in a "dumb" terminal like in acme or writing to a file
-                global_progress.terminal = stderr;
+                global_progress.terminal_mode = .ansi_escape_codes;
+            } else if (is_windows and stderr.isTty()) {
+                global_progress.terminal_mode = TerminalMode{ .windows_api = .{
+                    .code_page = windows.kernel32.GetConsoleOutputCP(),
+                } };
             }
 
-            const can_clear_terminal = global_progress.supports_ansi_escape_codes or global_progress.is_windows_terminal;
-            if (global_progress.terminal == null or !can_clear_terminal) {
+            if (global_progress.terminal_mode == .off) {
                 return .{ .index = .none };
             }
 
@@ -404,7 +405,11 @@ pub fn start(options: Options) Node {
                 };
             }
 
-            if (std.Thread.spawn(.{}, updateThreadRun, .{})) |thread| {
+            if (switch (global_progress.terminal_mode) {
+                .off => unreachable, // handled a few lines above
+                .ansi_escape_codes => std.Thread.spawn(.{}, updateThreadRun, .{}),
+                .windows_api => if (is_windows) std.Thread.spawn(.{}, windowsApiUpdateThreadRun, .{}) else unreachable,
+            }) |thread| {
                 global_progress.update_thread = thread;
             } else |err| {
                 std.log.warn("unable to spawn thread for printing progress to terminal: {s}", .{@errorName(err)});
@@ -438,13 +443,42 @@ fn updateThreadRun() void {
 
     {
         const resize_flag = wait(global_progress.initial_delay_ns);
+        if (@atomicLoad(bool, &global_progress.done, .seq_cst)) return;
         maybeUpdateSize(resize_flag);
 
+        const buffer = computeRedraw(&serialized_buffer);
+        if (stderr_mutex.tryLock()) {
+            defer stderr_mutex.unlock();
+            write(buffer) catch return;
+        }
+    }
+
+    while (true) {
+        const resize_flag = wait(global_progress.refresh_rate_ns);
+
         if (@atomicLoad(bool, &global_progress.done, .seq_cst)) {
             stderr_mutex.lock();
             defer stderr_mutex.unlock();
-            return clearTerminal();
+            return clearWrittenWithEscapeCodes() catch {};
+        }
+
+        maybeUpdateSize(resize_flag);
+
+        const buffer = computeRedraw(&serialized_buffer);
+        if (stderr_mutex.tryLock()) {
+            defer stderr_mutex.unlock();
+            write(buffer) catch return;
         }
+    }
+}
+
+fn windowsApiUpdateThreadRun() void {
+    var serialized_buffer: Serialized.Buffer = undefined;
+
+    {
+        const resize_flag = wait(global_progress.initial_delay_ns);
+        if (@atomicLoad(bool, &global_progress.done, .seq_cst)) return;
+        maybeUpdateSize(resize_flag);
 
         const buffer = computeRedraw(&serialized_buffer);
         if (stderr_mutex.tryLock()) {
@@ -455,17 +489,19 @@ fn updateThreadRun() void {
 
     while (true) {
         const resize_flag = wait(global_progress.refresh_rate_ns);
-        maybeUpdateSize(resize_flag);
 
         if (@atomicLoad(bool, &global_progress.done, .seq_cst)) {
             stderr_mutex.lock();
             defer stderr_mutex.unlock();
-            return clearTerminal();
+            return clearWrittenWindowsApi() catch {};
         }
 
+        maybeUpdateSize(resize_flag);
+
         const buffer = computeRedraw(&serialized_buffer);
         if (stderr_mutex.tryLock()) {
             defer stderr_mutex.unlock();
+            clearWrittenWindowsApi() catch return;
             write(buffer) catch return;
         }
     }
@@ -476,7 +512,7 @@ fn updateThreadRun() void {
 /// During the lock, any `std.Progress` information is cleared from the terminal.
 pub fn lockStdErr() void {
     stderr_mutex.lock();
-    clearTerminal();
+    clearWrittenWithEscapeCodes() catch {};
 }
 
 pub fn unlockStdErr() void {
@@ -504,7 +540,7 @@ fn ipcThreadRun(fd: posix.fd_t) anyerror!void {
         _ = wait(global_progress.refresh_rate_ns);
 
         if (@atomicLoad(bool, &global_progress.done, .seq_cst))
-            return clearTerminal();
+            return;
 
         const serialized = serialize(&serialized_buffer);
         writeIpc(fd, serialized) catch |err| switch (err) {
@@ -569,41 +605,36 @@ const TreeSymbol = enum {
         var max: usize = 0;
         inline for (@typeInfo(Encoding).Enum.fields) |field| {
             const len = symbol.bytes(@field(Encoding, field.name)).len;
-            if (len > max) max = len;
+            max = @max(max, len);
         }
         return max;
     }
 };
 
-fn appendTreeSymbol(comptime symbol: TreeSymbol, buf: []u8, start_i: usize) usize {
-    if (builtin.os.tag == .windows and global_progress.is_windows_terminal) {
-        const bytes = switch (global_progress.console_code_page) {
-            // Code page 437 is the default code page and contains the box drawing symbols
-            437 => symbol.bytes(.code_page_437),
-            // UTF-8
-            65001 => symbol.bytes(.utf8),
-            // Fall back to ASCII approximation
-            else => symbol.bytes(.ascii),
-        };
-        @memcpy(buf[start_i..][0..bytes.len], bytes);
-        return start_i + bytes.len;
+fn appendTreeSymbol(symbol: TreeSymbol, buf: []u8, start_i: usize) usize {
+    switch (global_progress.terminal_mode) {
+        .off => unreachable,
+        .ansi_escape_codes => {
+            const bytes = symbol.escapeSeq();
+            buf[start_i..][0..bytes.len].* = bytes.*;
+            return start_i + bytes.len;
+        },
+        .windows_api => |windows_api| {
+            const bytes = if (!is_windows) unreachable else switch (windows_api.code_page) {
+                // Code page 437 is the default code page and contains the box drawing symbols
+                437 => symbol.bytes(.code_page_437),
+                // UTF-8
+                65001 => symbol.bytes(.utf8),
+                // Fall back to ASCII approximation
+                else => symbol.bytes(.ascii),
+            };
+            @memcpy(buf[start_i..][0..bytes.len], bytes);
+            return start_i + bytes.len;
+        },
     }
-
-    // Drawing the tree is disabled when ansi escape codes are not supported
-    assert(global_progress.supports_ansi_escape_codes);
-
-    const bytes = symbol.escapeSeq();
-    buf[start_i..][0..bytes.len].* = bytes.*;
-    return start_i + bytes.len;
 }
 
-fn clearTerminal() void {
-    if (builtin.os.tag == .windows and global_progress.is_windows_terminal) {
-        return clearTerminalWindowsApi() catch {
-            global_progress.terminal = null;
-        };
-    }
-
+fn clearWrittenWithEscapeCodes() anyerror!void {
     if (global_progress.written_newline_count == 0) return;
 
     var i: usize = 0;
@@ -618,9 +649,7 @@ fn clearTerminal() void {
     i += finish_sync.len;
 
     global_progress.accumulated_newline_count = 0;
-    write(buf[0..i]) catch {
-        global_progress.terminal = null;
-    };
+    try write(buf[0..i]);
 }
 
 fn computeClear(buf: []u8, start_i: usize) usize {
@@ -645,7 +674,7 @@ fn computeClear(buf: []u8, start_i: usize) usize {
 /// U+25BA or ►
 const windows_api_start_marker = 0x25BA;
 
-fn clearTerminalWindowsApi() error{Unexpected}!void {
+fn clearWrittenWindowsApi() error{Unexpected}!void {
     // This uses a 'marker' strategy. The idea is:
     // - Always write a marker (in this case U+25BA or ►) at the beginning of the progress
     // - Get the current cursor position (at the end of the progress)
@@ -667,7 +696,7 @@ fn clearTerminalWindowsApi() error{Unexpected}!void {
     //   like any of the available attributes are invisible/benign.
     const prev_nl_n = global_progress.written_newline_count;
     if (prev_nl_n > 0) {
-        const handle = (global_progress.terminal orelse return).handle;
+        const handle = global_progress.terminal.handle;
         const screen_area = @as(windows.DWORD, global_progress.cols) * global_progress.rows;
 
         var console_info: windows.CONSOLE_SCREEN_BUFFER_INFO = undefined;
@@ -777,14 +806,14 @@ const SavedMetadata = struct {
     nodes_len: u8,
 
     fn getIpcFd(metadata: SavedMetadata) posix.fd_t {
-        return if (builtin.os.tag == .windows)
+        return if (is_windows)
             @ptrFromInt(@as(usize, metadata.ipc_fd) << 2)
         else
             metadata.ipc_fd;
     }
 
     fn setIpcFd(fd: posix.fd_t) u16 {
-        return @intCast(if (builtin.os.tag == .windows)
+        return @intCast(if (is_windows)
             @shrExact(@intFromPtr(fd), 2)
         else
             fd);
@@ -1019,35 +1048,21 @@ fn computeRedraw(serialized_buffer: *Serialized.Buffer) []u8 {
     var i: usize = 0;
     const buf = global_progress.draw_buffer;
 
-    if (global_progress.supports_ansi_escape_codes) {
-        buf[i..][0..start_sync.len].* = start_sync.*;
-        i += start_sync.len;
-
-        i = computeClear(buf, i);
-    } else if (builtin.os.tag == .windows and global_progress.is_windows_terminal) {
-        clearTerminalWindowsApi() catch {
-            global_progress.terminal = null;
-            return buf[0..0];
-        };
+    buf[i..][0..start_sync.len].* = start_sync.*;
+    i += start_sync.len;
 
-        // Write the marker that we will use to find the beginning of the progress when clearing.
-        // Note: This doesn't have to use WriteConsoleW, but doing so avoids dealing with the code page.
-        var num_chars_written: windows.DWORD = undefined;
-        const handle = (global_progress.terminal orelse return buf[0..0]).handle;
-        if (windows.kernel32.WriteConsoleW(handle, &[_]u16{windows_api_start_marker}, 1, &num_chars_written, null) == 0) {
-            global_progress.terminal = null;
-            return buf[0..0];
-        }
+    switch (global_progress.terminal_mode) {
+        .off => unreachable,
+        .ansi_escape_codes => i = computeClear(buf, i),
+        .windows_api => if (!is_windows) unreachable,
     }
 
     global_progress.accumulated_newline_count = 0;
     const root_node_index: Node.Index = @enumFromInt(0);
     i = computeNode(buf, i, serialized, children, root_node_index);
 
-    if (global_progress.supports_ansi_escape_codes) {
-        buf[i..][0..finish_sync.len].* = finish_sync.*;
-        i += finish_sync.len;
-    }
+    buf[i..][0..finish_sync.len].* = finish_sync.*;
+    i += finish_sync.len;
 
     return buf[0..i];
 }
@@ -1075,15 +1090,15 @@ fn computePrefix(
         buf[i..][0..prefix.len].* = prefix.*;
         i += prefix.len;
     } else {
-        const upper_bound_len = TreeSymbol.line.maxByteLen() + line_upper_bound_len;
+        const upper_bound_len = comptime (TreeSymbol.line.maxByteLen() + line_upper_bound_len);
         if (i + upper_bound_len > buf.len) return buf.len;
         i = appendTreeSymbol(.line, buf, i);
     }
     return i;
 }
 
-const line_upper_bound_len = @max(TreeSymbol.tee.maxByteLen(), TreeSymbol.langle.maxByteLen()) + "[4294967296/4294967296] ".len +
-    Node.max_name_len + finish_sync.len;
+const line_upper_bound_len = @max(TreeSymbol.tee.maxByteLen(), TreeSymbol.langle.maxByteLen()) +
+    "[4294967296/4294967296] ".len + Node.max_name_len + finish_sync.len;
 
 fn computeNode(
     buf: []u8,
@@ -1157,8 +1172,7 @@ fn withinRowLimit(p: *Progress) bool {
 }
 
 fn write(buf: []const u8) anyerror!void {
-    const tty = global_progress.terminal orelse return;
-    try tty.writeAll(buf);
+    try global_progress.terminal.writeAll(buf);
     global_progress.written_newline_count = global_progress.accumulated_newline_count;
 }
 
@@ -1218,23 +1232,23 @@ fn writeIpc(fd: posix.fd_t, serialized: Serialized) error{BrokenPipe}!void {
 fn maybeUpdateSize(resize_flag: bool) void {
     if (!resize_flag) return;
 
-    const fd = (global_progress.terminal orelse return).handle;
+    const fd = global_progress.terminal.handle;
 
-    if (builtin.os.tag == .windows) {
+    if (is_windows) {
         var info: windows.CONSOLE_SCREEN_BUFFER_INFO = undefined;
 
-        if (windows.kernel32.GetConsoleScreenBufferInfo(fd, &info) == windows.FALSE) {
+        if (windows.kernel32.GetConsoleScreenBufferInfo(fd, &info) != windows.FALSE) {
+            // In the old Windows console, dwSize.Y is the line count of the
+            // entire scrollback buffer, so we use this instead so that we
+            // always get the size of the screen.
+            const screen_height = info.srWindow.Bottom - info.srWindow.Top;
+            global_progress.rows = @intCast(screen_height);
+            global_progress.cols = @intCast(info.dwSize.X);
+        } else {
             std.log.debug("failed to determine terminal size; using conservative guess 80x25", .{});
             global_progress.rows = 25;
             global_progress.cols = 80;
         }
-
-        // In the old Windows console, dwSize.Y is the line count of the entire
-        // scrollback buffer, so we use this instead so that we always get the
-        // size of the screen.
-        const screen_height = info.srWindow.Bottom - info.srWindow.Top;
-        global_progress.rows = @intCast(screen_height);
-        global_progress.cols = @intCast(info.dwSize.X);
     } else {
         var winsize: posix.winsize = .{
             .ws_row = 0,