Commit da007f318b

Alexandros Naskos <alex_naskos@hotmail.com>
2020-11-17 00:08:04
Implement std.fs.Watch on Windows Use unmanaged containers in std.fs.Watch
1 parent 5112ab8
Changed files (4)
lib/std/fs/watch.zig
@@ -3,7 +3,7 @@
 // This file is part of [zig](https://ziglang.org/), which is MIT licensed.
 // The MIT license requires this copyright notice to be included in all copies
 // and substantial portions of the software.
-const std = @import("../std.zig");
+const std = @import("std");
 const builtin = @import("builtin");
 const event = std.event;
 const assert = std.debug.assert;
@@ -24,14 +24,6 @@ const WatchEventId = enum {
     Delete,
 };
 
-fn eqlString(a: []const u16, b: []const u16) bool {
-    return mem.eql(u16, a, b);
-}
-
-fn hashString(s: []const u16) u32 {
-    return @truncate(u32, std.hash.Wyhash.hash(0, mem.sliceAsBytes(s)));
-}
-
 const WatchEventError = error{
     UserResourceLimitReached,
     SystemResources,
@@ -69,21 +61,15 @@ pub fn Watch(comptime V: type) type {
         const WindowsOsData = struct {
             table_lock: event.Lock,
             dir_table: DirTable,
-            all_putters: std.atomic.Queue(Put),
-            ref_count: std.atomic.Int(usize),
-
-            const Put = struct {
-                putter: anyframe,
-                cancelled: bool = false,
-            };
+            cancelled: bool = false,
 
-            const DirTable = std.StringHashMap(*Dir);
-            const FileTable = std.HashMap([]const u16, V, hashString, eqlString);
+            const DirTable = std.StringHashMapUnmanaged(*Dir);
+            const FileTable = std.StringHashMapUnmanaged(V);
 
             const Dir = struct {
                 putter_frame: @Frame(windowsDirReader),
                 file_table: FileTable,
-                table_lock: event.Lock,
+                dir_handle: os.windows.HANDLE,
             };
         };
 
@@ -94,8 +80,8 @@ pub fn Watch(comptime V: type) type {
             table_lock: event.Lock,
             cancelled: bool = false,
 
-            const WdTable = std.AutoHashMap(i32, Dir);
-            const FileTable = std.StringHashMap(V);
+            const WdTable = std.AutoHashMapUnmanaged(i32, Dir);
+            const FileTable = std.StringHashMapUnmanaged(V);
 
             const Dir = struct {
                 dirname: []const u8,
@@ -148,10 +134,9 @@ pub fn Watch(comptime V: type) type {
                         .os_data = OsData{
                             .table_lock = event.Lock{},
                             .dir_table = OsData.DirTable.init(allocator),
-                            .ref_count = std.atomic.Int(usize).init(1),
-                            .all_putters = std.atomic.Queue(WindowsOsData.Put).init(),
                         },
                     };
+
                     var buf = try allocator.alloc(Event.Error!Event, event_buf_count);
                     self.channel.init(buf);
                     return self;
@@ -160,12 +145,15 @@ pub fn Watch(comptime V: type) type {
                 .macos, .freebsd, .netbsd, .dragonfly, .openbsd => {
                     self.* = Self{
                         .allocator = allocator,
-                        .channel = channel,
+                        .channel = undefined,
                         .os_data = OsData{
                             .table_lock = event.Lock.init(),
                             .file_table = OsData.FileTable.init(allocator),
                         },
                     };
+
+                    var buf = try allocator.alloc(Event.Error!Event, event_buf_count);
+                    self.channel.init(buf);
                     return self;
                 },
                 else => @compileError("Unsupported OS"),
@@ -206,35 +194,38 @@ pub fn Watch(comptime V: type) type {
                     self.allocator.destroy(self);
                 },
                 .windows => {
-                    while (self.os_data.all_putters.get()) |putter_node| {
-                        putter_node.cancelled = true;
-                        await putter_node.frame;
+                    self.os_data.cancelled = true;
+                    var dir_it = self.os_data.dir_table.iterator();
+                    while (dir_it.next()) |dir_entry| {
+                        if (windows.kernel32.CancelIoEx(dir_entry.value.dir_handle, null) != 0) {
+                            // We canceled the pending ReadDirectoryChangesW operation, but our
+                            // frame is still suspending, now waiting indefinitely.
+                            // Thus, it is safe to resume it ourslves
+                            resume dir_entry.value.putter_frame;
+                        } else {
+                            std.debug.assert(windows.kernel32.GetLastError() == .NOT_FOUND);
+                            // We are at another suspend point, we can await safely for the
+                            // function to exit the loop
+                            await dir_entry.value.putter_frame;
+                        }
+
+                        self.allocator.free(dir_entry.key);
+                        var file_it = dir_entry.value.file_table.iterator();
+                        while (file_it.next()) |file_entry| {
+                            self.allocator.free(file_entry.key);
+                        }
+                        dir_entry.value.file_table.deinit(self.allocator);
+                        self.allocator.destroy(dir_entry.value);
                     }
-                    self.deref();
+                    self.os_data.dir_table.deinit(self.allocator);
+                    self.allocator.free(self.channel.buffer_nodes);
+                    self.channel.deinit();
+                    self.allocator.destroy(self);
                 },
                 else => @compileError("Unsupported OS"),
             }
         }
 
-        fn ref(self: *Self) void {
-            _ = self.os_data.ref_count.incr();
-        }
-
-        fn deref(self: *Self) void {
-            if (self.os_data.ref_count.decr() == 1) {
-                self.os_data.table_lock.deinit();
-                var it = self.os_data.dir_table.iterator();
-                while (it.next()) |entry| {
-                    self.allocator.free(entry.key);
-                    self.allocator.destroy(entry.value);
-                }
-                self.os_data.dir_table.deinit();
-                self.channel.deinit();
-                self.allocator.destroy(self.channel.buffer_nodes);
-                self.allocator.destroy(self);
-            }
-        }
-
         pub fn addFile(self: *Self, file_path: []const u8, value: V) !?V {
             switch (builtin.os.tag) {
                 .macos, .freebsd, .netbsd, .dragonfly, .openbsd => return addFileKEvent(self, file_path, value),
@@ -342,7 +333,7 @@ pub fn Watch(comptime V: type) type {
             const held = self.os_data.table_lock.acquire();
             defer held.release();
 
-            const gop = try self.os_data.wd_table.getOrPut(wd);
+            const gop = try self.os_data.wd_table.getOrPut(self.allocator, wd);
             if (!gop.found_existing) {
                 gop.entry.value = OsData.Dir{
                     .dirname = try self.allocator.dupe(u8, dirname),
@@ -351,7 +342,7 @@ pub fn Watch(comptime V: type) type {
             }
 
             const dir = &gop.entry.value;
-            const file_table_gop = try dir.file_table.getOrPut(basename);
+            const file_table_gop = try dir.file_table.getOrPut(self.allocator, basename);
             if (file_table_gop.found_existing) {
                 const prev_value = file_table_gop.entry.value;
                 file_table_gop.entry.value = value;
@@ -365,89 +356,67 @@ pub fn Watch(comptime V: type) type {
 
         fn addFileWindows(self: *Self, file_path: []const u8, value: V) !?V {
             // TODO we might need to convert dirname and basename to canonical file paths ("short"?)
-            const dirname = try self.allocator.dupe(u8, std.fs.path.dirname(file_path) orelse ".");
-            var dirname_consumed = false;
-            defer if (!dirname_consumed) self.allocator.free(dirname);
-
-            const dirname_utf16le = try std.unicode.utf8ToUtf16LeWithNull(self.allocator, dirname);
-            defer self.allocator.free(dirname_utf16le);
+            const dirname = std.fs.path.dirname(file_path) orelse ".";
+            var dirname_path_space: windows.PathSpace = undefined;
+            dirname_path_space.len = try std.unicode.utf8ToUtf16Le(&dirname_path_space.data, dirname);
+            dirname_path_space.data[dirname_path_space.len] = 0;
 
-            // TODO https://github.com/ziglang/zig/issues/265
             const basename = std.fs.path.basename(file_path);
-            const basename_utf16le_null = try std.unicode.utf8ToUtf16LeWithNull(self.allocator, basename);
-            var basename_utf16le_null_consumed = false;
-            defer if (!basename_utf16le_null_consumed) self.allocator.free(basename_utf16le_null);
-            const basename_utf16le_no_null = basename_utf16le_null[0 .. basename_utf16le_null.len - 1];
-
-            const dir_handle = try windows.OpenFile(dirname_utf16le, .{
-                .dir = std.fs.cwd().fd,
-                .access_mask = windows.FILE_LIST_DIRECTORY,
-                .creation = windows.FILE_OPEN,
-                .io_mode = .blocking,
-                .open_dir = true,
-            });
-            var dir_handle_consumed = false;
-            defer if (!dir_handle_consumed) windows.CloseHandle(dir_handle);
+            var basename_path_space: windows.PathSpace = undefined;
+            basename_path_space.len = try std.unicode.utf8ToUtf16Le(&basename_path_space.data, basename);
+            basename_path_space.data[basename_path_space.len] = 0;
 
             const held = self.os_data.table_lock.acquire();
             defer held.release();
 
-            const gop = try self.os_data.dir_table.getOrPut(dirname);
+            const gop = try self.os_data.dir_table.getOrPut(self.allocator, dirname);
             if (gop.found_existing) {
-                const dir = gop.kv.value;
-                const held_dir_lock = dir.table_lock.acquire();
-                defer held_dir_lock.release();
+                const dir = gop.entry.value;
 
-                const file_gop = try dir.file_table.getOrPut(basename_utf16le_no_null);
+                const file_gop = try dir.file_table.getOrPut(self.allocator, basename);
                 if (file_gop.found_existing) {
-                    const prev_value = file_gop.kv.value;
-                    file_gop.kv.value = value;
+                    const prev_value = file_gop.entry.value;
+                    file_gop.entry.value = value;
                     return prev_value;
                 } else {
-                    file_gop.kv.value = value;
-                    basename_utf16le_null_consumed = true;
+                    file_gop.entry.value = value;
+                    file_gop.entry.key = try self.allocator.dupe(u8, basename);
                     return null;
                 }
             } else {
                 errdefer _ = self.os_data.dir_table.remove(dirname);
+                const dir_handle = try windows.OpenFile(dirname_path_space.span(), .{
+                    .dir = std.fs.cwd().fd,
+                    .access_mask = windows.FILE_LIST_DIRECTORY,
+                    .creation = windows.FILE_OPEN,
+                    .io_mode = .evented,
+                    .open_dir = true,
+                });
+                errdefer windows.CloseHandle(dir_handle);
+
                 const dir = try self.allocator.create(OsData.Dir);
                 errdefer self.allocator.destroy(dir);
 
+                gop.entry.key = try self.allocator.dupe(u8, dirname);
+                errdefer self.allocator.free(gop.entry.key);
+
                 dir.* = OsData.Dir{
                     .file_table = OsData.FileTable.init(self.allocator),
-                    .table_lock = event.Lock.init(),
                     .putter_frame = undefined,
+                    .dir_handle = dir_handle,
                 };
-                gop.kv.value = dir;
-                assert((try dir.file_table.put(basename_utf16le_no_null, value)) == null);
-                basename_utf16le_null_consumed = true;
-
-                dir.putter_frame = async self.windowsDirReader(dir_handle, dir);
-                dir_handle_consumed = true;
-
-                dirname_consumed = true;
-
+                gop.entry.value = dir;
+                try dir.file_table.put(self.allocator, try self.allocator.dupe(u8, basename), value);
+                dir.putter_frame = async self.windowsDirReader(dir, gop.entry.key);
                 return null;
             }
         }
 
-        fn windowsDirReader(self: *Self, dir_handle: windows.HANDLE, dir: *OsData.Dir) void {
-            self.ref();
-            defer self.deref();
-
-            defer os.close(dir_handle);
-
-            var putter_node = std.atomic.Queue(anyframe).Node{
-                .data = .{ .putter = @frame() },
-                .prev = null,
-                .next = null,
-            };
-            self.os_data.all_putters.put(&putter_node);
-            defer _ = self.os_data.all_putters.remove(&putter_node);
-
+        fn windowsDirReader(self: *Self, dir: *OsData.Dir, dirname: []const u8) void {
+            defer os.close(dir.dir_handle);
             var resume_node = Loop.ResumeNode.Basic{
                 .base = Loop.ResumeNode{
-                    .id = Loop.ResumeNode.Id.Basic,
+                    .id = .Basic,
                     .handle = @frame(),
                     .overlapped = windows.OVERLAPPED{
                         .Internal = 0,
@@ -458,81 +427,75 @@ pub fn Watch(comptime V: type) type {
                     },
                 },
             };
-            var event_buf: [4096]u8 align(@alignOf(windows.FILE_NOTIFY_INFORMATION)) = undefined;
 
-            // TODO handle this error not in the channel but in the setup
-            _ = windows.CreateIoCompletionPort(
-                dir_handle,
-                global_event_loop.os_data.io_port,
-                undefined,
-                undefined,
-            ) catch |err| {
-                self.channel.put(err);
-                return;
-            };
+            var event_buf: [4096]u8 align(@alignOf(windows.FILE_NOTIFY_INFORMATION)) = undefined;
 
-            while (!putter_node.data.cancelled) {
-                {
-                    // TODO only 1 beginOneEvent for the whole function
-                    global_event_loop.beginOneEvent();
-                    errdefer global_event_loop.finishOneEvent();
-                    errdefer {
-                        _ = windows.kernel32.CancelIoEx(dir_handle, &resume_node.base.overlapped);
-                    }
-                    suspend {
-                        _ = windows.kernel32.ReadDirectoryChangesW(
-                            dir_handle,
-                            &event_buf,
-                            @intCast(windows.DWORD, event_buf.len),
-                            windows.FALSE, // watch subtree
-                            windows.FILE_NOTIFY_CHANGE_FILE_NAME | windows.FILE_NOTIFY_CHANGE_DIR_NAME |
-                                windows.FILE_NOTIFY_CHANGE_ATTRIBUTES | windows.FILE_NOTIFY_CHANGE_SIZE |
-                                windows.FILE_NOTIFY_CHANGE_LAST_WRITE | windows.FILE_NOTIFY_CHANGE_LAST_ACCESS |
-                                windows.FILE_NOTIFY_CHANGE_CREATION | windows.FILE_NOTIFY_CHANGE_SECURITY,
-                            null, // number of bytes transferred (unused for async)
-                            &resume_node.base.overlapped,
-                            null, // completion routine - unused because we use IOCP
-                        );
-                    }
+            global_event_loop.beginOneEvent();
+            defer global_event_loop.finishOneEvent();
+
+            while (!self.os_data.cancelled) main_loop: {
+                suspend {
+                    _ = windows.kernel32.ReadDirectoryChangesW(
+                        dir.dir_handle,
+                        &event_buf,
+                        event_buf.len,
+                        windows.FALSE, // watch subtree
+                        windows.FILE_NOTIFY_CHANGE_FILE_NAME | windows.FILE_NOTIFY_CHANGE_DIR_NAME |
+                            windows.FILE_NOTIFY_CHANGE_ATTRIBUTES | windows.FILE_NOTIFY_CHANGE_SIZE |
+                            windows.FILE_NOTIFY_CHANGE_LAST_WRITE | windows.FILE_NOTIFY_CHANGE_LAST_ACCESS |
+                            windows.FILE_NOTIFY_CHANGE_CREATION | windows.FILE_NOTIFY_CHANGE_SECURITY,
+                        null, // number of bytes transferred (unused for async)
+                        &resume_node.base.overlapped,
+                        null, // completion routine - unused because we use IOCP
+                    );
                 }
+
                 var bytes_transferred: windows.DWORD = undefined;
-                if (windows.kernel32.GetOverlappedResult(dir_handle, &resume_node.base.overlapped, &bytes_transferred, windows.FALSE) == 0) {
-                    const err = switch (windows.kernel32.GetLastError()) {
+                if (windows.kernel32.GetOverlappedResult(
+                    dir.dir_handle,
+                    &resume_node.base.overlapped,
+                    &bytes_transferred,
+                    windows.FALSE,
+                ) == 0) {
+                    const potential_error = windows.kernel32.GetLastError();
+                    const err = switch (potential_error) {
+                        .OPERATION_ABORTED, .IO_INCOMPLETE => err_blk: {
+                            if (self.os_data.cancelled)
+                                break :main_loop
+                            else
+                                break :err_blk windows.unexpectedError(potential_error);
+                        },
                         else => |err| windows.unexpectedError(err),
                     };
                     self.channel.put(err);
                 } else {
-                    // can't use @bytesToSlice because of the special variable length name field
-                    var ptr = event_buf[0..].ptr;
+                    var ptr: [*]u8 = &event_buf;
                     const end_ptr = ptr + bytes_transferred;
-                    var ev: *windows.FILE_NOTIFY_INFORMATION = undefined;
-                    while (@ptrToInt(ptr) < @ptrToInt(end_ptr)) : (ptr += ev.NextEntryOffset) {
-                        ev = @ptrCast(*windows.FILE_NOTIFY_INFORMATION, ptr);
+                    while (@ptrToInt(ptr) < @ptrToInt(end_ptr)) {
+                        const ev = @ptrCast(*const windows.FILE_NOTIFY_INFORMATION, ptr);
                         const emit = switch (ev.Action) {
                             windows.FILE_ACTION_REMOVED => WatchEventId.Delete,
-                            windows.FILE_ACTION_MODIFIED => WatchEventId.CloseWrite,
+                            windows.FILE_ACTION_MODIFIED => .CloseWrite,
                             else => null,
                         };
                         if (emit) |id| {
-                            const basename_utf16le = ([*]u16)(&ev.FileName)[0 .. ev.FileNameLength / 2];
-                            const user_value = blk: {
-                                const held = dir.table_lock.acquire();
-                                defer held.release();
-
-                                if (dir.file_table.get(basename_utf16le)) |entry| {
-                                    break :blk entry.value;
-                                } else {
-                                    break :blk null;
-                                }
-                            };
-                            if (user_value) |v| {
+                            const basename_ptr = @ptrCast([*]u16, ptr + @sizeOf(windows.FILE_NOTIFY_INFORMATION));
+                            const basename_utf16le = basename_ptr[0 .. ev.FileNameLength / 2];
+                            var basename_data: [std.fs.MAX_PATH_BYTES]u8 = undefined;
+                            const basename = basename_data[0 .. std.unicode.utf16leToUtf8(&basename_data, basename_utf16le) catch unreachable];
+
+                            if (dir.file_table.getEntry(basename)) |entry| {
                                 self.channel.put(Event{
                                     .id = id,
-                                    .data = v,
+                                    .data = entry.value,
+                                    .dirname = dirname,
+                                    .basename = entry.key,
                                 });
                             }
                         }
+
                         if (ev.NextEntryOffset == 0) break;
+                        ptr = @alignCast(@alignOf(windows.FILE_NOTIFY_INFORMATION), ptr + ev.NextEntryOffset);
                     }
                 }
             }
@@ -554,8 +517,21 @@ pub fn Watch(comptime V: type) type {
                     }
                     return null;
                 },
+                .windows => {
+                    const dirname = std.fs.path.dirname(file_path) orelse ".";
+                    const basename = std.fs.path.basename(file_path);
+
+                    const held = self.os_data.table_lock.acquire();
+                    defer held.release();
+
+                    const dir = self.os_data.dir_table.get(dirname) orelse return null;
+                    if (dir.file_table.remove(basename)) |file_entry| {
+                        self.allocator.free(file_entry.key);
+                        return file_entry.value;
+                    }
+                    return null;
+                },
                 .macos, .freebsd, .netbsd, .dragonfly, .openbsd => @panic("TODO"),
-                .windows => return @panic("TODO"),
                 else => @compileError("Unsupported OS"),
             }
         }
@@ -565,7 +541,7 @@ pub fn Watch(comptime V: type) type {
 
             defer {
                 std.debug.assert(self.os_data.wd_table.count() == 0);
-                self.os_data.wd_table.deinit();
+                self.os_data.wd_table.deinit(self.allocator);
                 os.close(self.os_data.inotify_fd);
                 self.allocator.free(self.channel.buffer_nodes);
                 self.channel.deinit();
@@ -585,9 +561,6 @@ pub fn Watch(comptime V: type) type {
                         const basename_ptr = ptr + @sizeOf(os.linux.inotify_event);
                         const basename = std.mem.span(@ptrCast([*:0]u8, basename_ptr));
 
-                        const held = self.os_data.table_lock.acquire();
-                        defer held.release();
-
                         const dir = &self.os_data.wd_table.get(ev.wd).?;
                         if (dir.file_table.getEntry(basename)) |file_value| {
                             self.channel.put(Event{
@@ -607,17 +580,14 @@ pub fn Watch(comptime V: type) type {
                                 self.allocator.free(file_entry.key);
                             }
                             self.allocator.free(wd_entry.value.dirname);
-                            wd_entry.value.file_table.deinit();
+                            wd_entry.value.file_table.deinit(self.allocator);
                         }
                     } else if (ev.mask & os.linux.IN_DELETE == os.linux.IN_DELETE) {
                         // File or directory was removed or deleted
                         const basename_ptr = ptr + @sizeOf(os.linux.inotify_event);
                         const basename = std.mem.span(@ptrCast([*:0]u8, basename_ptr));
 
-                        const held = self.os_data.table_lock.acquire();
-                        defer held.release();
                         const dir = &self.os_data.wd_table.get(ev.wd).?;
-
                         if (dir.file_table.getEntry(basename)) |file_value| {
                             self.channel.put(Event{
                                 .id = .Delete,
lib/std/os/windows/bits.zig
@@ -813,7 +813,8 @@ pub const FILE_NOTIFY_INFORMATION = extern struct {
     NextEntryOffset: DWORD,
     Action: DWORD,
     FileNameLength: DWORD,
-    FileName: [1]WCHAR,
+    // Flexible array member
+    // FileName: [1]WCHAR,
 };
 
 pub const FILE_ACTION_ADDED = 0x00000001;
lib/std/os/windows/kernel32.zig
@@ -8,7 +8,7 @@ usingnamespace @import("bits.zig");
 pub extern "kernel32" fn AddVectoredExceptionHandler(First: c_ulong, Handler: ?VECTORED_EXCEPTION_HANDLER) callconv(WINAPI) ?*c_void;
 pub extern "kernel32" fn RemoveVectoredExceptionHandler(Handle: HANDLE) callconv(WINAPI) c_ulong;
 
-pub extern "kernel32" fn CancelIoEx(hFile: HANDLE, lpOverlapped: LPOVERLAPPED) callconv(WINAPI) BOOL;
+pub extern "kernel32" fn CancelIoEx(hFile: HANDLE, lpOverlapped: ?LPOVERLAPPED) callconv(WINAPI) BOOL;
 
 pub extern "kernel32" fn CloseHandle(hObject: HANDLE) callconv(WINAPI) BOOL;
 
lib/std/os/windows.zig
@@ -109,7 +109,12 @@ pub fn OpenFile(sub_path_w: []const u16, options: OpenFileOptions) OpenError!HAN
             0,
         );
         switch (rc) {
-            .SUCCESS => return result,
+            .SUCCESS => {
+                if (options.io_mode == .evented) {
+                    _ = CreateIoCompletionPort(result, std.event.Loop.instance.?.os_data.io_port, undefined, undefined) catch undefined;
+                }
+                return result;
+            },
             .OBJECT_NAME_INVALID => unreachable,
             .OBJECT_NAME_NOT_FOUND => return error.FileNotFound,
             .OBJECT_PATH_NOT_FOUND => return error.FileNotFound,
@@ -418,8 +423,6 @@ pub fn ReadFile(in_hFile: HANDLE, buffer: []u8, offset: ?u64, io_mode: std.io.Mo
                 },
             },
         };
-        // TODO only call create io completion port once per fd
-        _ = CreateIoCompletionPort(in_hFile, loop.os_data.io_port, undefined, undefined) catch undefined;
         loop.beginOneEvent();
         suspend {
             // TODO handle buffer bigger than DWORD can hold
@@ -500,8 +503,6 @@ pub fn WriteFile(
                 },
             },
         };
-        // TODO only call create io completion port once per fd
-        _ = CreateIoCompletionPort(handle, loop.os_data.io_port, undefined, undefined) catch undefined;
         loop.beginOneEvent();
         suspend {
             const adjusted_len = math.cast(DWORD, bytes.len) catch maxInt(DWORD);