Commit cb20503990

Rohlem <rohlemF@gmail.com>
2020-12-23 02:39:02
std.os.windows.GetFinalPathNameByHandle: address non-structural review comments
1 parent 64c5f49
Changed files (5)
lib/std/os/windows/bits.zig
@@ -66,7 +66,6 @@ pub const SIZE_T = usize;
 pub const TCHAR = if (UNICODE) WCHAR else u8;
 pub const UINT = c_uint;
 pub const ULONG_PTR = usize;
-pub const PULONG = *ULONG;
 pub const LONG_PTR = isize;
 pub const DWORD_PTR = ULONG_PTR;
 pub const UNICODE = false;
lib/std/os/windows/ntdll.zig
@@ -119,5 +119,5 @@ pub extern "NtDll" fn NtQueryObject(
     ObjectInformationClass: OBJECT_INFORMATION_CLASS,
     ObjectInformation: PVOID,
     ObjectInformationLength: ULONG,
-    ReturnLength: ?PULONG,
+    ReturnLength: ?*ULONG,
 ) callconv(WINAPI) NTSTATUS;
lib/std/os/windows.zig
@@ -29,7 +29,8 @@ pub const gdi32 = @import("windows/gdi32.zig");
 pub usingnamespace @import("windows/bits.zig");
 
 //version detection
-usingnamespace std.zig.system.windows;
+const version = std.zig.system.windows;
+const WindowsVersion = version.WindowsVersion;
 
 pub const self_process_handle = @intToPtr(HANDLE, maxInt(usize));
 
@@ -963,42 +964,44 @@ pub fn QueryObjectName(
     var full_buffer: [@sizeOf(OBJECT_NAME_INFORMATION) + PATH_MAX_WIDE * 2]u8 align(@alignOf(OBJECT_NAME_INFORMATION)) = undefined;
     var info = @ptrCast(*OBJECT_NAME_INFORMATION, &full_buffer);
     //buffer size is specified in bytes
-    const full_buffer_length = @intCast(ULONG, @sizeOf(OBJECT_NAME_INFORMATION) + std.math.min(PATH_MAX_WIDE, (out_buffer.len + 1) * 2));
     //last argument would return the length required for full_buffer, not exposed here
-    const rc = ntdll.NtQueryObject(handle, .ObjectNameInformation, full_buffer[0..], full_buffer_length, null);
-    return switch (rc) {
-        .SUCCESS => if (@ptrCast(?[*]WCHAR, info.Name.Buffer)) |buffer| blk: {
+    const rc = ntdll.NtQueryObject(handle, .ObjectNameInformation, &full_buffer, full_buffer.len, null);
+    switch (rc) {
+        .SUCCESS => {
+            // info.Name.Buffer from ObQueryNameString is documented to be null (and MaximumLength == 0)
+            // if the object was "unnamed", not sure if this can happen for file handles
+            if (info.Name.MaximumLength == 0) return error.Unexpected;
             //resulting string length is specified in bytes
             const path_length_unterminated = @divExact(info.Name.Length, 2);
             if (out_buffer.len < path_length_unterminated) {
                 return error.NameTooLong;
             }
-            std.mem.copy(WCHAR, out_buffer[0..path_length_unterminated], buffer[0..path_length_unterminated :0]);
-            break :blk out_buffer[0..path_length_unterminated];
-        } else error.Unexpected,
-        .ACCESS_DENIED => error.AccessDenied,
-        .INVALID_HANDLE => error.InvalidHandle,
-        .BUFFER_OVERFLOW, .BUFFER_TOO_SMALL => error.NameTooLong,
+            mem.copy(WCHAR, out_buffer[0..path_length_unterminated], info.Name.Buffer[0..path_length_unterminated]);
+            return out_buffer[0..path_length_unterminated];
+        },
+        .ACCESS_DENIED => return error.AccessDenied,
+        .INVALID_HANDLE => return error.InvalidHandle,
+        .BUFFER_OVERFLOW, .BUFFER_TOO_SMALL => return error.NameTooLong,
         //name_buffer.len >= @sizeOf(OBJECT_NAME_INFORMATION) holds statically
         .INFO_LENGTH_MISMATCH => unreachable,
-        else => |e| unexpectedStatus(e),
-    };
+        else => |e| return unexpectedStatus(e),
+    }
 }
 test "QueryObjectName" {
     if (comptime builtin.os.tag != .windows)
         return;
 
     //any file will do; canonicalization works on NTFS junctions and symlinks, hardlinks remain separate paths.
-    const file = try std.fs.openSelfExe(.{});
-    defer file.close();
-    //make this large enough for the test runner exe path
-    var out_buffer align(16) = std.mem.zeroes([1 << 10]u16);
+    var tmp = std.testing.tmpDir(.{});
+    defer tmp.cleanup();
+    const handle = tmp.dir.fd;
+    var out_buffer: [PATH_MAX_WIDE]u16 = undefined;
 
-    var result_path = try QueryObjectName(file.handle, out_buffer[0..]);
+    var result_path = try QueryObjectName(handle, &out_buffer);
     //insufficient size
-    std.testing.expectError(error.NameTooLong, QueryObjectName(file.handle, out_buffer[0 .. result_path.len - 1]));
+    std.testing.expectError(error.NameTooLong, QueryObjectName(handle, out_buffer[0 .. result_path.len - 1]));
     //exactly-sufficient size
-    _ = try QueryObjectName(file.handle, out_buffer[0..result_path.len]);
+    _ = try QueryObjectName(handle, out_buffer[0..result_path.len]);
 }
 
 pub const GetFinalPathNameByHandleError = error{
@@ -1030,16 +1033,16 @@ pub fn GetFinalPathNameByHandle(
     fmt: GetFinalPathNameByHandleFormat,
     out_buffer: []u16,
 ) GetFinalPathNameByHandleError![]u16 {
-    var path_buffer: [std.math.max(@sizeOf(FILE_NAME_INFORMATION), @sizeOf(OBJECT_NAME_INFORMATION)) + PATH_MAX_WIDE * 2]u8 align(@alignOf(FILE_NAME_INFORMATION)) = undefined;
+    var path_buffer: [math.max(@sizeOf(FILE_NAME_INFORMATION), @sizeOf(OBJECT_NAME_INFORMATION)) + PATH_MAX_WIDE * 2]u8 align(@alignOf(FILE_NAME_INFORMATION)) = undefined;
     var volume_buffer: [@sizeOf(FILE_NAME_INFORMATION) + MAX_PATH]u8 align(@alignOf(FILE_NAME_INFORMATION)) = undefined; // MAX_PATH bytes should be enough since it's Windows-defined name
 
     var file_name_u16: []const u16 = undefined;
     var volume_name_u16: []const u16 = undefined;
-    if ((comptime (targetVersionIsAtLeast(WindowsVersion.win10_rs4) != true)) //need explicit comptime, because error returns affect return type
-        and !runtimeVersionIsAtLeast(WindowsVersion.win10_rs4))
-    {
-        const final_path = QueryObjectName(hFile, std.mem.bytesAsSlice(u16, path_buffer[0..])) catch |err| return switch (err) {
-            error.InvalidHandle => error.FileNotFound, //close enough?
+    if ((comptime (std.builtin.os.version_range.windows.isAtLeast(WindowsVersion.win10_rs4) != true)) and !version.detectRuntimeVersion().isAtLeast(WindowsVersion.win10_rs4)) {
+        const final_path = QueryObjectName(hFile, mem.bytesAsSlice(u16, &path_buffer)) catch |err| return switch (err) {
+            // we assume InvalidHandle is close enough to FileNotFound in semantics
+            // to not further complicate the error set
+            error.InvalidHandle => error.FileNotFound,
             else => |e| e,
         };
 
@@ -1047,39 +1050,40 @@ pub fn GetFinalPathNameByHandle(
             if (out_buffer.len < final_path.len) {
                 return error.NameTooLong;
             }
-            std.mem.copy(u16, out_buffer[0..], final_path[0..]);
-            return final_path; //we can directly return the slice we received
+            mem.copy(u16, out_buffer, final_path);
+            return out_buffer[0..final_path.len];
         }
 
         //otherwise we need to parse the string for volume path for the .Dos logic below to work
         const expected_prefix = std.unicode.utf8ToUtf16LeStringLiteral("\\Device\\");
-        if (!std.mem.eql(u16, expected_prefix, final_path[0..expected_prefix.len])) {
-            //TODO find out if this can occur, and if we need to handle it differently
-            //(i.e. how to determine the end of a volume name)
-            return error.BadPathName;
-        }
-        const index = std.mem.indexOfPos(u16, final_path, expected_prefix.len, &[_]u16{'\\'}) orelse unreachable;
+
+        // TODO find out if a path can start with something besides `\Device\<volume name>`,
+        // and if we need to handle it differently
+        // (i.e. how to determine the start and end of the volume name in that case)
+        if (!mem.eql(u16, expected_prefix, final_path[0..expected_prefix.len])) return error.Unexpected;
+
+        const index = mem.indexOfPos(u16, final_path, expected_prefix.len, &[_]u16{'\\'}) orelse unreachable;
         volume_name_u16 = final_path[0..index];
         file_name_u16 = final_path[index..];
 
         //fallthrough for fmt.volume_name != .Nt
     } else {
         // Get normalized path; doesn't include volume name though.
-        try QueryInformationFile(hFile, .FileNormalizedNameInformation, path_buffer[0..]);
-        const file_name = @ptrCast(*const FILE_NAME_INFORMATION, &path_buffer[0]);
-        file_name_u16 = @ptrCast([*]const u16, &file_name.FileName[0])[0..@divExact(file_name.FileNameLength, 2)];
+        try QueryInformationFile(hFile, .FileNormalizedNameInformation, &path_buffer);
+        const file_name = @ptrCast(*const FILE_NAME_INFORMATION, &path_buffer);
+        file_name_u16 = @ptrCast([*]const u16, &file_name.FileName)[0..@divExact(file_name.FileNameLength, 2)];
 
         // Get NT volume name.
-        try QueryInformationFile(hFile, .FileVolumeNameInformation, volume_buffer[0..]);
-        const volume_name_info = @ptrCast(*const FILE_NAME_INFORMATION, &volume_buffer[0]);
-        volume_name_u16 = @ptrCast([*]const u16, &volume_name_info.FileName[0])[0..@divExact(volume_name_info.FileNameLength, 2)];
+        try QueryInformationFile(hFile, .FileVolumeNameInformation, &volume_buffer);
+        const volume_name_info = @ptrCast(*const FILE_NAME_INFORMATION, &volume_buffer);
+        volume_name_u16 = @ptrCast([*]const u16, &volume_name_info.FileName)[0..@divExact(volume_name_info.FileNameLength, 2)];
 
         if (fmt.volume_name == .Nt) {
             // Nothing to do, we simply copy the bytes to the user-provided buffer.
             if (out_buffer.len < volume_name_u16.len + file_name_u16.len) return error.NameTooLong;
 
-            std.mem.copy(u16, out_buffer[0..], volume_name_u16);
-            std.mem.copy(u16, out_buffer[volume_name_u16.len..], file_name_u16);
+            mem.copy(u16, out_buffer, volume_name_u16);
+            mem.copy(u16, out_buffer[volume_name_u16.len..], file_name_u16);
 
             return out_buffer[0 .. volume_name_u16.len + file_name_u16.len];
         }
@@ -1124,7 +1128,7 @@ pub fn GetFinalPathNameByHandle(
             input_struct.DeviceNameLength = @intCast(USHORT, volume_name_u16.len * 2);
             @memcpy(input_buf[@sizeOf(MOUNTMGR_MOUNT_POINT)..], @ptrCast([*]const u8, volume_name_u16.ptr), volume_name_u16.len * 2);
 
-            DeviceIoControl(mgmt_handle, IOCTL_MOUNTMGR_QUERY_POINTS, input_buf[0..], output_buf[0..]) catch |err| switch (err) {
+            DeviceIoControl(mgmt_handle, IOCTL_MOUNTMGR_QUERY_POINTS, &input_buf, &output_buf) catch |err| switch (err) {
                 error.AccessDenied => unreachable,
                 else => |e| return e,
             };
@@ -1144,22 +1148,20 @@ pub fn GetFinalPathNameByHandle(
 
                 // Look for `\DosDevices\` prefix. We don't really care if there are more than one symlinks
                 // with traditional DOS drive letters, so pick the first one available.
-                const prefix_u8 = "\\DosDevices\\";
-                var prefix_buf_u16: [prefix_u8.len]u16 = undefined;
-                const prefix_len_u16 = std.unicode.utf8ToUtf16Le(prefix_buf_u16[0..], prefix_u8[0..]) catch unreachable;
-                const prefix = prefix_buf_u16[0..prefix_len_u16];
+                var prefix_buf = std.unicode.utf8ToUtf16LeStringLiteral("\\DosDevices\\");
+                const prefix = prefix_buf[0..prefix_buf.len];
 
-                if (std.mem.startsWith(u16, symlink, prefix)) {
+                if (mem.startsWith(u16, symlink, prefix)) {
                     const drive_letter = symlink[prefix.len..];
 
                     if (out_buffer.len < drive_letter.len + file_name_u16.len) return error.NameTooLong;
 
-                    std.mem.copy(u16, out_buffer[0..], drive_letter);
-                    std.mem.copy(u16, out_buffer[drive_letter.len..], file_name_u16);
+                    mem.copy(u16, out_buffer, drive_letter);
+                    mem.copy(u16, out_buffer[drive_letter.len..], file_name_u16);
                     const total_len = drive_letter.len + file_name_u16.len;
 
                     // Validate that DOS does not contain any spurious nul bytes.
-                    if (std.mem.indexOfScalar(u16, out_buffer[0..total_len], 0)) |_| {
+                    if (mem.indexOfScalar(u16, out_buffer[0..total_len], 0)) |_| {
                         return error.BadPathName;
                     }
 
@@ -1179,15 +1181,14 @@ test "GetFinalPathNameByHandle" {
         return;
 
     //any file will do
-    const file = try std.fs.openSelfExe(.{});
-    defer file.close();
-    const handle = file.handle;
-    //make this large enough for the test runner exe path
-    var buffer = std.mem.zeroes([1 << 10]u16);
+    var tmp = std.testing.tmpDir(.{});
+    defer tmp.cleanup();
+    const handle = tmp.dir.fd;
+    var buffer: [PATH_MAX_WIDE]u16 = undefined;
 
     //check with sufficient size
-    const nt_length = (try GetFinalPathNameByHandle(handle, .{ .volume_name = .Nt }, buffer[0..])).len;
-    const dos_length = (try GetFinalPathNameByHandle(handle, .{ .volume_name = .Dos }, buffer[0..])).len;
+    const nt_length = (try GetFinalPathNameByHandle(handle, .{ .volume_name = .Nt }, &buffer)).len;
+    const dos_length = (try GetFinalPathNameByHandle(handle, .{ .volume_name = .Dos }, &buffer)).len;
 
     //check with insufficient size
     std.testing.expectError(error.NameTooLong, GetFinalPathNameByHandle(handle, .{ .volume_name = .Nt }, buffer[0 .. nt_length - 1]));
lib/std/zig/system/windows.zig
@@ -43,19 +43,3 @@ pub fn detectRuntimeVersion() WindowsVersion {
 
     return @intToEnum(WindowsVersion, version);
 }
-
-/// Returns whether the target os versions are uniformly at least as new as the argument:
-/// true/false if this holds for the entire target range, null if it only holds for some.
-pub fn targetVersionIsAtLeast(requested_version: WindowsVersion) ?bool {
-    const requested = @enumToInt(requested_version);
-    const version_range = std.builtin.os.version_range.windows;
-    const target_min = @enumToInt(version_range.min);
-    const target_max = @enumToInt(version_range.max);
-    return if (target_max < requested) false else if (target_min >= requested) true else null;
-}
-
-/// Returns whether the runtime os version is at least as new as the argument.
-pub fn runtimeVersionIsAtLeast(requested_version: WindowsVersion) bool {
-    return targetVersionIsAtLeast(requested_version) orelse
-        (@enumToInt(detectRuntimeVersion()) >= @enumToInt(requested_version));
-}
lib/std/target.zig
@@ -95,7 +95,7 @@ pub const Target = struct {
             win7 = 0x06010000,
             win8 = 0x06020000,
             win8_1 = 0x06030000,
-            win10 = 0x0A000000,
+            win10 = 0x0A000000, //aka win10_th1
             win10_th2 = 0x0A000001,
             win10_rs1 = 0x0A000002,
             win10_rs2 = 0x0A000003,
@@ -103,26 +103,35 @@ pub const Target = struct {
             win10_rs4 = 0x0A000005,
             win10_rs5 = 0x0A000006,
             win10_19h1 = 0x0A000007,
-            win10_20h1 = 0x0A000008,
+            win10_vb = 0x0A000008, //aka win10_19h2
+            win10_mn = 0x0A000009, //aka win10_20h1
+            win10_fe = 0x0A00000A, //aka win10_20h2
             _,
 
             /// Latest Windows version that the Zig Standard Library is aware of
-            pub const latest = WindowsVersion.win10_20h1;
+            pub const latest = WindowsVersion.win10_fe;
 
             /// Compared against build numbers reported by the runtime to distinguish win10 versions,
             /// where 0x0A000000 + index corresponds to the WindowsVersion u32 value.
             pub const known_win10_build_numbers = [_]u32{
-                10240, //win10
+                10240, //win10 aka win10_th1
                 10586, //win10_th2
-                14393, //win_rs1
-                15063, //win_rs2
-                16299, //win_rs3
-                17134, //win_rs4
-                17763, //win_rs5
-                18362, //win_19h1
-                19041, //win_20h1
+                14393, //win10_rs1
+                15063, //win10_rs2
+                16299, //win10_rs3
+                17134, //win10_rs4
+                17763, //win10_rs5
+                18362, //win10_19h1
+                18363, //win10_vb aka win10_19h2
+                19041, //win10_mn aka win10_20h1
+                19042, //win10_fe aka win10_20h2
             };
 
+            /// Returns whether the first version `self` is newer (greater) than or equal to the second version `ver`.
+            pub fn isAtLeast(self: WindowsVersion, ver: WindowsVersion) bool {
+                return @enumToInt(self) >= @enumToInt(ver);
+            }
+
             pub const Range = struct {
                 min: WindowsVersion,
                 max: WindowsVersion,