Commit 0bde55e881

viri <hi@viri.moe>
2022-01-16 00:40:45
std.Thread(windows): use NT internals for name fns
1 parent 12d6bce
Changed files (3)
lib/std/os/windows/kernel32.zig
@@ -400,6 +400,3 @@ pub extern "kernel32" fn SleepConditionVariableSRW(
 pub extern "kernel32" fn TryAcquireSRWLockExclusive(s: *SRWLOCK) callconv(WINAPI) BOOLEAN;
 pub extern "kernel32" fn AcquireSRWLockExclusive(s: *SRWLOCK) callconv(WINAPI) void;
 pub extern "kernel32" fn ReleaseSRWLockExclusive(s: *SRWLOCK) callconv(WINAPI) void;
-
-pub extern "kernel32" fn SetThreadDescription(hThread: HANDLE, lpThreadDescription: LPCWSTR) callconv(WINAPI) HRESULT;
-pub extern "kernel32" fn GetThreadDescription(hThread: HANDLE, ppszThreadDescription: *LPWSTR) callconv(WINAPI) HRESULT;
lib/std/os/windows.zig
@@ -2029,21 +2029,6 @@ pub fn unexpectedStatus(status: NTSTATUS) std.os.UnexpectedError {
     return error.Unexpected;
 }
 
-pub fn SetThreadDescription(hThread: HANDLE, lpThreadDescription: LPCWSTR) !void {
-    if (kernel32.SetThreadDescription(hThread, lpThreadDescription) == 0) {
-        switch (kernel32.GetLastError()) {
-            else => |err| return unexpectedError(err),
-        }
-    }
-}
-pub fn GetThreadDescription(hThread: HANDLE, ppszThreadDescription: *LPWSTR) !void {
-    if (kernel32.GetThreadDescription(hThread, ppszThreadDescription) == 0) {
-        switch (kernel32.GetLastError()) {
-            else => |err| return unexpectedError(err),
-        }
-    }
-}
-
 pub const Win32Error = @import("windows/win32error.zig").Win32Error;
 pub const NTSTATUS = @import("windows/ntstatus.zig").NTSTATUS;
 pub const LANG = @import("windows/lang.zig");
lib/std/Thread.zig
@@ -4,6 +4,7 @@
 
 const std = @import("std.zig");
 const builtin = @import("builtin");
+const math = std.math;
 const os = std.os;
 const assert = std.debug.assert;
 const target = builtin.target;
@@ -85,20 +86,28 @@ pub fn setName(self: Thread, name: []const u8) SetNameError!void {
             try file.writer().writeAll(name);
             return;
         },
-        .windows => if (target.os.isAtLeast(.windows, .win10_rs1)) |res| {
-            // SetThreadDescription is only available since version 1607, which is 10.0.14393.795
-            // See https://en.wikipedia.org/wiki/Microsoft_Windows_SDK
-            if (!res) return error.Unsupported;
-
-            var name_buf_w: [max_name_len:0]u16 = undefined;
-            const length = try std.unicode.utf8ToUtf16Le(&name_buf_w, name);
-            name_buf_w[length] = 0;
+        .windows => {
+            var buf: [max_name_len]u16 = undefined;
+            const len = try std.unicode.utf8ToUtf16Le(&buf, name);
+            const byte_len = math.cast(c_ushort, len * 2) catch return error.NameTooLong;
+
+            // Note: NT allocates its own copy, no use-after-free here.
+            const unicode_string = os.windows.UNICODE_STRING{
+                .Length = byte_len,
+                .MaximumLength = byte_len,
+                .Buffer = &buf,
+            };
 
-            try os.windows.SetThreadDescription(
+            switch (os.windows.ntdll.NtSetInformationThread(
                 self.getHandle(),
-                @ptrCast(os.windows.LPWSTR, &name_buf_w),
-            );
-            return;
+                .ThreadNameInformation,
+                &unicode_string,
+                @sizeOf(os.windows.UNICODE_STRING),
+            )) {
+                .SUCCESS => return,
+                .NOT_IMPLEMENTED => return error.Unsupported,
+                else => |err| return os.windows.unexpectedStatus(err),
+            }
         },
         .macos, .ios, .watchos, .tvos => if (use_pthreads) {
             // There doesn't seem to be a way to set the name for an arbitrary thread, only the current one.
@@ -188,18 +197,25 @@ pub fn getName(self: Thread, buffer_ptr: *[max_name_len:0]u8) GetNameError!?[]co
             // musl doesn't provide pthread_getname_np and there's no way to retrieve the thread id of an arbitrary thread.
             return error.Unsupported;
         },
-        .windows => if (target.os.isAtLeast(.windows, .win10_rs1)) |res| {
-            // GetThreadDescription is only available since version 1607, which is 10.0.14393.795
-            // See https://en.wikipedia.org/wiki/Microsoft_Windows_SDK
-            if (!res) return error.Unsupported;
-
-            var name_w: os.windows.LPWSTR = undefined;
-            try os.windows.GetThreadDescription(self.getHandle(), &name_w);
-            defer os.windows.LocalFree(name_w);
+        .windows => {
+            const buf_capacity = @sizeOf(os.windows.UNICODE_STRING) + (@sizeOf(u16) * max_name_len);
+            var buf: [buf_capacity]u8 align(@alignOf(os.windows.UNICODE_STRING)) = undefined;
 
-            const data_len = try std.unicode.utf16leToUtf8(buffer, std.mem.sliceTo(name_w, 0));
-
-            return if (data_len >= 1) buffer[0..data_len] else null;
+            switch (os.windows.ntdll.NtQueryInformationThread(
+                self.getHandle(),
+                .ThreadNameInformation,
+                &buf,
+                buf_capacity,
+                null,
+            )) {
+                .SUCCESS => {
+                    const string = @ptrCast(*const os.windows.UNICODE_STRING, &buf);
+                    const len = try std.unicode.utf16leToUtf8(buffer, string.Buffer[0 .. string.Length / 2]);
+                    return if (len > 0) buffer[0..len] else null;
+                },
+                .NOT_IMPLEMENTED => return error.Unsupported,
+                else => |err| return os.windows.unexpectedStatus(err),
+            }
         },
         .macos, .ios, .watchos, .tvos => if (use_pthreads) {
             const err = std.c.pthread_getname_np(self.getHandle(), buffer.ptr, max_name_len + 1);