Commit c87f79c957

Ryan Liptak <squeek502@hotmail.com>
2022-01-15 13:56:32
os.getenvW: Fix case-insensitivity for Unicode env var names
Windows does Unicode-aware case-insensitivity comparisons for environment variable names. Before, os.getenvW was only doing ASCII case-insensitivity. We can take advantage of RtlEqualUnicodeString in NtDll to get the proper Unicode case insensitivity.
1 parent 8841a71
Changed files (2)
lib
std
os
windows
lib/std/os/windows/ntdll.zig
@@ -223,6 +223,12 @@ pub extern "ntdll" fn RtlWaitOnAddress(
     Timeout: ?*const LARGE_INTEGER,
 ) callconv(WINAPI) NTSTATUS;
 
+pub extern "ntdll" fn RtlEqualUnicodeString(
+    String1: *const UNICODE_STRING,
+    String2: *const UNICODE_STRING,
+    CaseInSensitive: BOOLEAN,
+) callconv(WINAPI) BOOLEAN;
+
 pub extern "ntdll" fn NtLockFile(
     FileHandle: HANDLE,
     Event: ?HANDLE,
lib/std/os.zig
@@ -1715,15 +1715,13 @@ pub fn getenvZ(key: [*:0]const u8) ?[]const u8 {
 
 /// Windows-only. Get an environment variable with a null-terminated, WTF-16 encoded name.
 /// See also `getenv`.
-/// This function first attempts a case-sensitive lookup. If no match is found, and `key`
-/// is ASCII, then it attempts a second case-insensitive lookup.
+/// This function performs a Unicode-aware case-insensitive lookup using RtlEqualUnicodeString.
 pub fn getenvW(key: [*:0]const u16) ?[:0]const u16 {
     if (builtin.os.tag != .windows) {
         @compileError("std.os.getenvW is a Windows-only API");
     }
     const key_slice = mem.sliceTo(key, 0);
     const ptr = windows.peb().ProcessParameters.Environment;
-    var ascii_match: ?[:0]const u16 = null;
     var i: usize = 0;
     while (ptr[i] != 0) {
         const key_start = i;
@@ -1737,22 +1735,25 @@ pub fn getenvW(key: [*:0]const u16) ?[:0]const u16 {
         while (ptr[i] != 0) : (i += 1) {}
         const this_value = ptr[value_start..i :0];
 
-        if (mem.eql(u16, key_slice, this_key)) return this_value;
-
-        ascii_check: {
-            if (ascii_match != null) break :ascii_check;
-            if (key_slice.len != this_key.len) break :ascii_check;
-            for (key_slice) |a_c, key_index| {
-                const a = math.cast(u8, a_c) catch break :ascii_check;
-                const b = math.cast(u8, this_key[key_index]) catch break :ascii_check;
-                if (std.ascii.toLower(a) != std.ascii.toLower(b)) break :ascii_check;
-            }
-            ascii_match = this_value;
+        const key_string_bytes = @intCast(u16, key_slice.len * 2);
+        const key_string = windows.UNICODE_STRING{
+            .Length = key_string_bytes,
+            .MaximumLength = key_string_bytes,
+            .Buffer = @intToPtr([*]u16, @ptrToInt(key)),
+        };
+        const this_key_string_bytes = @intCast(u16, this_key.len * 2);
+        const this_key_string = windows.UNICODE_STRING{
+            .Length = this_key_string_bytes,
+            .MaximumLength = this_key_string_bytes,
+            .Buffer = this_key.ptr,
+        };
+        if (windows.ntdll.RtlEqualUnicodeString(&key_string, &this_key_string, windows.TRUE) == windows.TRUE) {
+            return this_value;
         }
 
         i += 1; // skip over null byte
     }
-    return ascii_match;
+    return null;
 }
 
 pub const GetCwdError = error{