Commit 64c5f4979e

Rohlem <rohlemF@gmail.com>
2020-12-17 16:37:40
std.os.windows.GetFinalPathNameByHandle: replace kernel32 by ntdll call
Removes the call to kernel32.GetFinalPathNameByHandleW in favor of NtQueryObject, which means we can reuse the other codepath's logic for DOS naming.
1 parent 964bbcd
Changed files (1)
lib
lib/std/os/windows.zig
@@ -1001,19 +1001,13 @@ test "QueryObjectName" {
     _ = try QueryObjectName(file.handle, out_buffer[0..result_path.len]);
 }
 
-pub const GetFinalPathNameByHandleError = error {
-        BadPathName,
-        FileNotFound,
-        NameTooLong,
-        Unexpected,
-    }
-    || if((comptime builtin.os.tag != .windows) or (targetVersionIsAtLeast(WindowsVersion.win10_rs4) == true))
-        error {}
-    else
-        error {
-            AccessDenied,
-            SystemResources,
-        };
+pub const GetFinalPathNameByHandleError = error{
+    AccessDenied,
+    BadPathName,
+    FileNotFound,
+    NameTooLong,
+    Unexpected,
+};
 
 /// Specifies how to format volume path in the result of `GetFinalPathNameByHandle`.
 /// Defaults to DOS volume names.
@@ -1036,75 +1030,64 @@ pub fn GetFinalPathNameByHandle(
     fmt: GetFinalPathNameByHandleFormat,
     out_buffer: []u16,
 ) GetFinalPathNameByHandleError![]u16 {
+    var path_buffer: [std.math.max(@sizeOf(FILE_NAME_INFORMATION), @sizeOf(OBJECT_NAME_INFORMATION)) + PATH_MAX_WIDE * 2]u8 align(@alignOf(FILE_NAME_INFORMATION)) = undefined;
+    var volume_buffer: [@sizeOf(FILE_NAME_INFORMATION) + MAX_PATH]u8 align(@alignOf(FILE_NAME_INFORMATION)) = undefined; // MAX_PATH bytes should be enough since it's Windows-defined name
 
-    var path_buffer: [@sizeOf(FILE_NAME_INFORMATION) + PATH_MAX_WIDE * 2]u8 align(@alignOf(FILE_NAME_INFORMATION)) = undefined;
-
+    var file_name_u16: []const u16 = undefined;
+    var volume_name_u16: []const u16 = undefined;
     if ((comptime (targetVersionIsAtLeast(WindowsVersion.win10_rs4) != true)) //need explicit comptime, because error returns affect return type
-       and !runtimeVersionIsAtLeast(WindowsVersion.win10_rs4)) {
-        // TODO: directly replace/emulate QueryInformationFile of .FileNormalizedNameInformation
-        // with ntdll instead of calling into kernel32
-        // (probably using some less-powerful query and looping over path segments)
-        const flags: DWORD = FILE_NAME_NORMALIZED | switch(fmt.volume_name) {
-            .Dos => @as(DWORD, VOLUME_NAME_DOS),
-            .Nt => @as(DWORD, VOLUME_NAME_NT),
+        and !runtimeVersionIsAtLeast(WindowsVersion.win10_rs4))
+    {
+        const final_path = QueryObjectName(hFile, std.mem.bytesAsSlice(u16, path_buffer[0..])) catch |err| return switch (err) {
+            error.InvalidHandle => error.FileNotFound, //close enough?
+            else => |e| e,
         };
-        const wide_path_buffer = std.mem.bytesAsSlice(u16, path_buffer[0..]);
-        const rc = kernel32.GetFinalPathNameByHandleW(hFile, wide_path_buffer.ptr, @intCast(u32, wide_path_buffer.len), flags);
-        if (rc == 0) {
-            switch (kernel32.GetLastError()) {
-                .FILE_NOT_FOUND => return error.FileNotFound,
-                .PATH_NOT_FOUND => return error.FileNotFound,
-                .NOT_ENOUGH_MEMORY => return error.SystemResources,
-                .FILENAME_EXCED_RANGE => return error.NameTooLong,
-                .ACCESS_DENIED => return error.AccessDenied, //can happen in SMB sub-queries for parent path segments
-                .INVALID_PARAMETER => unreachable,
-                else => |err| return unexpectedError(err),
+
+        if (fmt.volume_name == .Nt) {
+            if (out_buffer.len < final_path.len) {
+                return error.NameTooLong;
             }
+            std.mem.copy(u16, out_buffer[0..], final_path[0..]);
+            return final_path; //we can directly return the slice we received
         }
 
-        //in case of failure, rc == length of string INCLUDING null terminator,
-        if (rc > wide_path_buffer.len) return error.NameTooLong;
-        //in case of success, rc == length of string EXCLUDING null terminator
-        const result_slice = switch(fmt.volume_name) {
-            .Dos => blk: {
-                const expected_prefix = [_]u16{'\\', '\\', '?', '\\'};
-                if (!std.mem.eql(u16, expected_prefix[0..], wide_path_buffer[0..expected_prefix.len])) {
-                    return error.BadPathName;
-                }
-                break :blk wide_path_buffer[expected_prefix.len..rc:0];
-            },
-            //no prefix here
-            .Nt => wide_path_buffer[0..rc:0],
-        };
-        if(result_slice.len > out_buffer.len) return error.NameTooLong;
-        std.mem.copy(u16, out_buffer[0..], result_slice);
-        return out_buffer[0..result_slice.len];
-    }
-
-    // Get normalized path; doesn't include volume name though.
-    try QueryInformationFile(hFile, .FileNormalizedNameInformation, path_buffer[0..]);
-
-    // Get NT volume name.
-    var volume_buffer: [@sizeOf(FILE_NAME_INFORMATION) + MAX_PATH]u8 align(@alignOf(FILE_NAME_INFORMATION)) = undefined; // MAX_PATH bytes should be enough since it's Windows-defined name
-    try QueryInformationFile(hFile, .FileVolumeNameInformation, volume_buffer[0..]);
+        //otherwise we need to parse the string for volume path for the .Dos logic below to work
+        const expected_prefix = std.unicode.utf8ToUtf16LeStringLiteral("\\Device\\");
+        if (!std.mem.eql(u16, expected_prefix, final_path[0..expected_prefix.len])) {
+            //TODO find out if this can occur, and if we need to handle it differently
+            //(i.e. how to determine the end of a volume name)
+            return error.BadPathName;
+        }
+        const index = std.mem.indexOfPos(u16, final_path, expected_prefix.len, &[_]u16{'\\'}) orelse unreachable;
+        volume_name_u16 = final_path[0..index];
+        file_name_u16 = final_path[index..];
 
-    const file_name = @ptrCast(*const FILE_NAME_INFORMATION, &path_buffer[0]);
-    const file_name_u16 = @ptrCast([*]const u16, &file_name.FileName[0])[0 .. file_name.FileNameLength / 2];
+        //fallthrough for fmt.volume_name != .Nt
+    } else {
+        // Get normalized path; doesn't include volume name though.
+        try QueryInformationFile(hFile, .FileNormalizedNameInformation, path_buffer[0..]);
+        const file_name = @ptrCast(*const FILE_NAME_INFORMATION, &path_buffer[0]);
+        file_name_u16 = @ptrCast([*]const u16, &file_name.FileName[0])[0..@divExact(file_name.FileNameLength, 2)];
 
-    const volume_name = @ptrCast(*const FILE_NAME_INFORMATION, &volume_buffer[0]);
+        // Get NT volume name.
+        try QueryInformationFile(hFile, .FileVolumeNameInformation, volume_buffer[0..]);
+        const volume_name_info = @ptrCast(*const FILE_NAME_INFORMATION, &volume_buffer[0]);
+        volume_name_u16 = @ptrCast([*]const u16, &volume_name_info.FileName[0])[0..@divExact(volume_name_info.FileNameLength, 2)];
 
-    switch (fmt.volume_name) {
-        .Nt => {
+        if (fmt.volume_name == .Nt) {
             // Nothing to do, we simply copy the bytes to the user-provided buffer.
-            const volume_name_u16 = @ptrCast([*]const u16, &volume_name.FileName[0])[0 .. volume_name.FileNameLength / 2];
-
             if (out_buffer.len < volume_name_u16.len + file_name_u16.len) return error.NameTooLong;
 
             std.mem.copy(u16, out_buffer[0..], volume_name_u16);
             std.mem.copy(u16, out_buffer[volume_name_u16.len..], file_name_u16);
 
             return out_buffer[0 .. volume_name_u16.len + file_name_u16.len];
-        },
+        }
+        //fallthrough for fmt.volume_name != .Nt
+    }
+
+    switch (fmt.volume_name) {
+        .Nt => unreachable, //handled above
         .Dos => {
             // Get DOS volume name. DOS volume names are actually symbolic link objects to the
             // actual NT volume. For example:
@@ -1138,8 +1121,8 @@ pub fn GetFinalPathNameByHandle(
 
             var input_struct = @ptrCast(*MOUNTMGR_MOUNT_POINT, &input_buf[0]);
             input_struct.DeviceNameOffset = @sizeOf(MOUNTMGR_MOUNT_POINT);
-            input_struct.DeviceNameLength = @intCast(USHORT, volume_name.FileNameLength);
-            @memcpy(input_buf[@sizeOf(MOUNTMGR_MOUNT_POINT)..], @ptrCast([*]const u8, &volume_name.FileName[0]), volume_name.FileNameLength);
+            input_struct.DeviceNameLength = @intCast(USHORT, volume_name_u16.len * 2);
+            @memcpy(input_buf[@sizeOf(MOUNTMGR_MOUNT_POINT)..], @ptrCast([*]const u8, volume_name_u16.ptr), volume_name_u16.len * 2);
 
             DeviceIoControl(mgmt_handle, IOCTL_MOUNTMGR_QUERY_POINTS, input_buf[0..], output_buf[0..]) catch |err| switch (err) {
                 error.AccessDenied => unreachable,