Commit b219feb3f1

Andrew Kelley <superjoe30@gmail.com>
2018-08-09 22:48:44
initial windows implementation of std.event.fs.Watch
1 parent c63ec98
Changed files (4)
src-self-hosted
std
src-self-hosted/compilation.zig
@@ -288,6 +288,7 @@ pub const Compilation = struct {
         InvalidDarwinVersionString,
         UnsupportedLinkArchitecture,
         UserResourceLimitReached,
+        InvalidUtf8,
     };
 
     pub const Event = union(enum) {
std/event/fs.zig
@@ -681,6 +681,7 @@ pub const WatchEventError = error{
     UserResourceLimitReached,
     SystemResources,
     AccessDenied,
+    Unexpected, // TODO remove this possibility
 };
 
 pub fn Watch(comptime V: type) type {
@@ -699,27 +700,48 @@ pub fn Watch(comptime V: type) type {
                     value_ptr: *V,
                 };
             },
-            builtin.Os.linux => struct {
+
+            builtin.Os.linux => LinuxOsData,
+            builtin.Os.windows => WindowsOsData,
+
+            else => @compileError("Unsupported OS"),
+        };
+
+        const WindowsOsData = struct {
+            table_lock: event.Lock,
+            dir_table: DirTable,
+            all_putters: std.atomic.Queue(promise),
+            ref_count: std.atomic.Int(usize),
+
+            const DirTable = std.AutoHashMap([]const u8, *Dir);
+            const FileTable = std.AutoHashMap([]const u16, V);
+
+            const Dir = struct {
                 putter: promise,
-                inotify_fd: i32,
-                wd_table: WdTable,
+                file_table: FileTable,
                 table_lock: event.Lock,
+            };
+        };
 
-                const FileTable = std.AutoHashMap([]const u8, V);
-            },
-            else => @compileError("Unsupported OS"),
+        const LinuxOsData = struct {
+            putter: promise,
+            inotify_fd: i32,
+            wd_table: WdTable,
+            table_lock: event.Lock,
+
+            const WdTable = std.AutoHashMap(i32, Dir);
+            const FileTable = std.AutoHashMap([]const u8, V);
+
+            const Dir = struct {
+                dirname: []const u8,
+                file_table: FileTable,
+            };
         };
 
-        const WdTable = std.AutoHashMap(i32, Dir);
         const FileToHandle = std.AutoHashMap([]const u8, promise);
 
         const Self = this;
 
-        const Dir = struct {
-            dirname: []const u8,
-            file_table: OsData.FileTable,
-        };
-
         pub const Event = struct {
             id: Id,
             data: V,
@@ -741,6 +763,22 @@ pub fn Watch(comptime V: type) type {
                     _ = try async<loop.allocator> linuxEventPutter(inotify_fd, channel, &result);
                     return result;
                 },
+
+                builtin.Os.windows => {
+                    const self = try loop.allocator.createOne(Self);
+                    errdefer loop.allocator.destroy(self);
+                    self.* = Self{
+                        .channel = channel,
+                        .os_data = OsData{
+                            .table_lock = event.Lock.init(loop),
+                            .dir_table = OsData.DirTable.init(loop.allocator),
+                            .ref_count = std.atomic.Int(usize).init(1),
+                            .all_putters = std.atomic.Queue(promise).init(),
+                        },
+                    };
+                    return self;
+                },
+
                 builtin.Os.macosx => {
                     const self = try loop.allocator.createOne(Self);
                     errdefer loop.allocator.destroy(self);
@@ -758,9 +796,11 @@ pub fn Watch(comptime V: type) type {
             }
         }
 
+        /// All addFile calls and removeFile calls must have completed.
         pub fn destroy(self: *Self) void {
             switch (builtin.os) {
                 builtin.Os.macosx => {
+                    // TODO we need to cancel the coroutines before destroying the lock
                     self.os_data.table_lock.deinit();
                     var it = self.os_data.file_table.iterator();
                     while (it.next()) |entry| {
@@ -770,14 +810,41 @@ pub fn Watch(comptime V: type) type {
                     self.channel.destroy();
                 },
                 builtin.Os.linux => cancel self.os_data.putter,
+                builtin.Os.windows => {
+                    while (self.os_data.all_putters.get()) |putter_node| {
+                        cancel putter_node.data;
+                    }
+                    self.deref();
+                },
                 else => @compileError("Unsupported OS"),
             }
         }
 
+        fn ref(self: *Self) void {
+            _ = self.os_data.ref_count.incr();
+        }
+
+        fn deref(self: *Self) void {
+            if (self.os_data.ref_count.decr() == 1) {
+                const allocator = self.channel.loop.allocator;
+                self.os_data.table_lock.deinit();
+                var it = self.os_data.dir_table.iterator();
+                while (it.next()) |entry| {
+                    allocator.free(entry.key);
+                    // TODO why does freeing this memory crash the test?
+                    //allocator.destroy(entry.value);
+                }
+                self.os_data.dir_table.deinit();
+                self.channel.destroy();
+                allocator.destroy(self);
+            }
+        }
+
         pub async fn addFile(self: *Self, file_path: []const u8, value: V) !?V {
             switch (builtin.os) {
                 builtin.Os.macosx => return await (async addFileMacosx(self, file_path, value) catch unreachable),
                 builtin.Os.linux => return await (async addFileLinux(self, file_path, value) catch unreachable),
+                builtin.Os.windows => return await (async addFileWindows(self, file_path, value) catch unreachable),
                 else => @compileError("Unsupported OS"),
             }
         }
@@ -874,6 +941,8 @@ pub fn Watch(comptime V: type) type {
         }
 
         async fn addFileLinux(self: *Self, file_path: []const u8, value: V) !?V {
+            const value_copy = value;
+
             const dirname = os.path.dirname(file_path) orelse ".";
             const dirname_with_null = try std.cstr.addNullByte(self.channel.loop.allocator, dirname);
             var dirname_with_null_consumed = false;
@@ -896,7 +965,7 @@ pub fn Watch(comptime V: type) type {
 
             const gop = try self.os_data.wd_table.getOrPut(wd);
             if (!gop.found_existing) {
-                gop.kv.value = Dir{
+                gop.kv.value = OsData.Dir{
                     .dirname = dirname_with_null,
                     .file_table = OsData.FileTable.init(self.channel.loop.allocator),
                 };
@@ -907,15 +976,201 @@ pub fn Watch(comptime V: type) type {
             const file_table_gop = try dir.file_table.getOrPut(basename_with_null);
             if (file_table_gop.found_existing) {
                 const prev_value = file_table_gop.kv.value;
-                file_table_gop.kv.value = value;
+                file_table_gop.kv.value = value_copy;
                 return prev_value;
             } else {
-                file_table_gop.kv.value = value;
+                file_table_gop.kv.value = value_copy;
                 basename_with_null_consumed = true;
                 return null;
             }
         }
 
+        async fn addFileWindows(self: *Self, file_path: []const u8, value: V) !?V {
+            const value_copy = value;
+            // TODO we might need to convert dirname and basename to canonical file paths ("short"?)
+
+            const dirname = try std.mem.dupe(self.channel.loop.allocator, u8, os.path.dirname(file_path) orelse ".");
+            var dirname_consumed = false;
+            defer if (!dirname_consumed) self.channel.loop.allocator.free(dirname);
+
+            const dirname_utf16le = try std.unicode.utf8ToUtf16LeWithNull(self.channel.loop.allocator, dirname);
+            defer self.channel.loop.allocator.free(dirname_utf16le);
+
+            // TODO https://github.com/ziglang/zig/issues/265
+            const basename = os.path.basename(file_path);
+            const basename_utf16le_null = try std.unicode.utf8ToUtf16LeWithNull(self.channel.loop.allocator, basename);
+            var basename_utf16le_null_consumed = false;
+            defer if (!basename_utf16le_null_consumed) self.channel.loop.allocator.free(basename_utf16le_null);
+            const basename_utf16le_no_null = basename_utf16le_null[0..basename_utf16le_null.len-1];
+
+            const dir_handle = windows.CreateFileW(
+                dirname_utf16le.ptr,
+                windows.FILE_LIST_DIRECTORY,
+                windows.FILE_SHARE_READ | windows.FILE_SHARE_DELETE | windows.FILE_SHARE_WRITE,
+                null,
+                windows.OPEN_EXISTING,
+                windows.FILE_FLAG_BACKUP_SEMANTICS | windows.FILE_FLAG_OVERLAPPED,
+                null,
+            );
+            if (dir_handle == windows.INVALID_HANDLE_VALUE) {
+                const err = windows.GetLastError();
+                switch (err) {
+                    windows.ERROR.FILE_NOT_FOUND,
+                    windows.ERROR.PATH_NOT_FOUND,
+                    => return error.PathNotFound,
+                    else => return os.unexpectedErrorWindows(err),
+                }
+            }
+            var dir_handle_consumed = false;
+            defer if (!dir_handle_consumed) os.close(dir_handle);
+
+            const held = await (async self.os_data.table_lock.acquire() catch unreachable);
+            defer held.release();
+
+            const gop = try self.os_data.dir_table.getOrPut(dirname);
+            if (gop.found_existing) {
+                const dir = gop.kv.value;
+                const held_dir_lock = await (async dir.table_lock.acquire() catch unreachable);
+                defer held_dir_lock.release();
+
+                const file_gop = try dir.file_table.getOrPut(basename_utf16le_no_null);
+                if (file_gop.found_existing) {
+                    const prev_value = file_gop.kv.value;
+                    file_gop.kv.value = value_copy;
+                    return prev_value;
+                } else {
+                    file_gop.kv.value = value_copy;
+                    basename_utf16le_null_consumed = true;
+                    return null;
+                }
+            } else {
+                errdefer _ = self.os_data.dir_table.remove(dirname);
+                const dir = try self.channel.loop.allocator.createOne(OsData.Dir);
+                errdefer self.channel.loop.allocator.destroy(dir);
+
+                dir.* = OsData.Dir{
+                    .file_table = OsData.FileTable.init(self.channel.loop.allocator),
+                    .table_lock = event.Lock.init(self.channel.loop),
+                    .putter = undefined,
+                };
+                assert((try dir.file_table.put(basename_utf16le_no_null, value_copy)) == null);
+                basename_utf16le_null_consumed = true;
+
+                dir.putter = try async self.windowsDirReader(dir_handle, dir);
+                dir_handle_consumed = true;
+
+                dirname_consumed = true;
+
+                return null;
+            }
+        }
+
+        async fn windowsDirReader(self: *Self, dir_handle: windows.HANDLE, dir: *OsData.Dir) void {
+            // TODO https://github.com/ziglang/zig/issues/1194
+            suspend {
+                resume @handle();
+            }
+
+            self.ref();
+            defer self.deref();
+
+            defer os.close(dir_handle);
+
+            var putter_node = std.atomic.Queue(promise).Node{
+                .data = @handle(),
+                .prev = null,
+                .next = null,
+            };
+            self.os_data.all_putters.put(&putter_node);
+            defer _ = self.os_data.all_putters.remove(&putter_node);
+
+            var resume_node = Loop.ResumeNode.Basic{
+                .base = Loop.ResumeNode{
+                    .id = Loop.ResumeNode.Id.Basic,
+                    .handle = @handle(),
+                },
+            };
+            const completion_key = @ptrToInt(&resume_node.base);
+            var overlapped = windows.OVERLAPPED{
+                .Internal = 0,
+                .InternalHigh = 0,
+                .Offset = 0,
+                .OffsetHigh = 0,
+                .hEvent = null,
+            };
+            var event_buf: [4096]u8 align(@alignOf(windows.FILE_NOTIFY_INFORMATION)) = undefined;
+
+            while (true) {
+                _ = os.windowsCreateIoCompletionPort(
+                    dir_handle, self.channel.loop.os_data.io_port, completion_key, undefined,
+                ) catch |err| {
+                    await (async self.channel.put(err) catch unreachable);
+                    return;
+                };
+                {
+                    // TODO only 1 beginOneEvent for the whole coroutine
+                    self.channel.loop.beginOneEvent();
+                    errdefer self.channel.loop.finishOneEvent();
+                    suspend {
+                        _ = windows.ReadDirectoryChangesW(
+                            dir_handle,
+                            &event_buf,
+                            @intCast(windows.DWORD, event_buf.len),
+                            windows.FALSE, // watch subtree
+                            windows.FILE_NOTIFY_CHANGE_FILE_NAME        | windows.FILE_NOTIFY_CHANGE_DIR_NAME     |
+                                windows.FILE_NOTIFY_CHANGE_ATTRIBUTES   | windows.FILE_NOTIFY_CHANGE_SIZE         |
+                                windows.FILE_NOTIFY_CHANGE_LAST_WRITE   | windows.FILE_NOTIFY_CHANGE_LAST_ACCESS  |
+                                windows.FILE_NOTIFY_CHANGE_CREATION     | windows.FILE_NOTIFY_CHANGE_SECURITY,
+                            null, // number of bytes transferred (unused for async)
+                            &overlapped,
+                            null, // completion routine - unused because we use IOCP
+                        );
+                    }
+                }
+                var bytes_transferred: windows.DWORD = undefined;
+                if (windows.GetOverlappedResult(dir_handle, &overlapped, &bytes_transferred, windows.FALSE) == 0) {
+                    const errno = windows.GetLastError();
+                    const err = switch (errno) {
+                        else => os.unexpectedErrorWindows(errno),
+                    };
+                    await (async self.channel.put(err) catch unreachable);
+                } else {
+                    // can't use @bytesToSlice because of the special variable length name field
+                    var ptr = event_buf[0..].ptr;
+                    const end_ptr = ptr + bytes_transferred;
+                    var ev: *windows.FILE_NOTIFY_INFORMATION = undefined;
+                    while (@ptrToInt(ptr) < @ptrToInt(end_ptr)) : (ptr += ev.NextEntryOffset) {
+                        ev = @ptrCast(*windows.FILE_NOTIFY_INFORMATION, ptr);
+                        const emit = switch (ev.Action) {
+                            windows.FILE_ACTION_REMOVED => WatchEventId.Delete,
+                            windows.FILE_ACTION_MODIFIED => WatchEventId.CloseWrite,
+                            else => null,
+                        };
+                        if (emit) |id| {
+                            const basename_utf16le = ([*]u16)(&ev.FileName)[0..ev.FileNameLength/2];
+                            const user_value = blk: {
+                                const held = await (async dir.table_lock.acquire() catch unreachable);
+                                defer held.release();
+
+                                if (dir.file_table.get(basename_utf16le)) |entry| {
+                                    break :blk entry.value;
+                                } else {
+                                    break :blk null;
+                                }
+                            };
+                            if (user_value) |v| {
+                                await (async self.channel.put(Event{
+                                    .id = id,
+                                    .data = v,
+                                }) catch unreachable);
+                            }
+                        }
+                        if (ev.NextEntryOffset == 0) break;
+                    }
+                }
+            }
+        }
+
         pub async fn removeFile(self: *Self, file_path: []const u8) ?V {
             @panic("TODO");
         }
@@ -933,7 +1188,7 @@ pub fn Watch(comptime V: type) type {
                 .os_data = OsData{
                     .putter = @handle(),
                     .inotify_fd = inotify_fd,
-                    .wd_table = WdTable.init(loop.allocator),
+                    .wd_table = OsData.WdTable.init(loop.allocator),
                     .table_lock = event.Lock.init(loop),
                 },
             };
@@ -1065,15 +1320,15 @@ async fn testFsWatch(loop: *Loop) !void {
     const read_contents = try await try async readFile(loop, file_path, 1024 * 1024);
     assert(mem.eql(u8, read_contents, contents));
 
-    //// now watch the file
-    //var watch = try Watch(void).create(loop, 0);
-    //defer watch.destroy();
+    // now watch the file
+    var watch = try Watch(void).create(loop, 0);
+    defer watch.destroy();
 
-    //assert((try await try async watch.addFile(file_path, {})) == null);
+    assert((try await try async watch.addFile(file_path, {})) == null);
 
-    //const ev = try async watch.channel.get();
-    //var ev_consumed = false;
-    //defer if (!ev_consumed) cancel ev;
+    const ev = try async watch.channel.get();
+    var ev_consumed = false;
+    defer if (!ev_consumed) cancel ev;
 
     // overwrite line 2
     const fd = try await try async openReadWrite(loop, file_path, os.File.default_mode);
@@ -1083,11 +1338,11 @@ async fn testFsWatch(loop: *Loop) !void {
         try await try async pwritev(loop, fd, []const []const u8{"lorem ipsum"}, line2_offset);
     }
 
-    //ev_consumed = true;
-    //switch ((try await ev).id) {
-    //    WatchEventId.CloseWrite => {},
-    //    WatchEventId.Delete => @panic("wrong event"),
-    //}
+    ev_consumed = true;
+    switch ((try await ev).id) {
+        WatchEventId.CloseWrite => {},
+        WatchEventId.Delete => @panic("wrong event"),
+    }
 
     const contents_updated = try await try async readFile(loop, file_path, 1024 * 1024);
     assert(mem.eql(u8, contents_updated,
std/os/windows/kernel32.zig
@@ -11,7 +11,17 @@ pub extern "kernel32" stdcallcc fn CreateDirectoryA(
 ) BOOL;
 
 pub extern "kernel32" stdcallcc fn CreateFileA(
-    lpFileName: LPCSTR,
+    lpFileName: [*]const u8, // TODO null terminated pointer type
+    dwDesiredAccess: DWORD,
+    dwShareMode: DWORD,
+    lpSecurityAttributes: ?LPSECURITY_ATTRIBUTES,
+    dwCreationDisposition: DWORD,
+    dwFlagsAndAttributes: DWORD,
+    hTemplateFile: ?HANDLE,
+) HANDLE;
+
+pub extern "kernel32" stdcallcc fn CreateFileW(
+    lpFileName: [*]const u16, // TODO null terminated pointer type
     dwDesiredAccess: DWORD,
     dwShareMode: DWORD,
     lpSecurityAttributes: ?LPSECURITY_ATTRIBUTES,
@@ -129,6 +139,17 @@ pub extern "kernel32" stdcallcc fn QueryPerformanceCounter(lpPerformanceCount: *
 
 pub extern "kernel32" stdcallcc fn QueryPerformanceFrequency(lpFrequency: *LARGE_INTEGER) BOOL;
 
+pub extern "kernel32" stdcallcc fn ReadDirectoryChangesW(
+    hDirectory: HANDLE,
+    lpBuffer: [*]align(@alignOf(FILE_NOTIFY_INFORMATION)) u8,
+    nBufferLength: DWORD,
+    bWatchSubtree: BOOL,
+    dwNotifyFilter: DWORD,
+    lpBytesReturned: ?*DWORD,
+    lpOverlapped: ?*OVERLAPPED,
+    lpCompletionRoutine: LPOVERLAPPED_COMPLETION_ROUTINE,
+) BOOL;
+
 pub extern "kernel32" stdcallcc fn ReadFile(
     in_hFile: HANDLE,
     out_lpBuffer: [*]u8,
@@ -168,3 +189,30 @@ pub extern "kernel32" stdcallcc fn WriteFileEx(hFile: HANDLE, lpBuffer: [*]const
 pub extern "kernel32" stdcallcc fn LoadLibraryA(lpLibFileName: LPCSTR) ?HMODULE;
 
 pub extern "kernel32" stdcallcc fn FreeLibrary(hModule: HMODULE) BOOL;
+
+
+pub const FILE_NOTIFY_INFORMATION = extern struct {
+    NextEntryOffset: DWORD,
+    Action: DWORD,
+    FileNameLength: DWORD,
+    FileName: [1]WCHAR,
+};
+
+pub const FILE_ACTION_ADDED = 0x00000001;
+pub const FILE_ACTION_REMOVED = 0x00000002;
+pub const FILE_ACTION_MODIFIED = 0x00000003;
+pub const FILE_ACTION_RENAMED_OLD_NAME = 0x00000004;
+pub const FILE_ACTION_RENAMED_NEW_NAME = 0x00000005;
+
+pub const LPOVERLAPPED_COMPLETION_ROUTINE = ?extern fn(DWORD, DWORD, *OVERLAPPED) void;
+
+pub const FILE_LIST_DIRECTORY = 1;
+
+pub const FILE_NOTIFY_CHANGE_CREATION = 64;
+pub const FILE_NOTIFY_CHANGE_SIZE = 8;
+pub const FILE_NOTIFY_CHANGE_SECURITY = 256;
+pub const FILE_NOTIFY_CHANGE_LAST_ACCESS = 32;
+pub const FILE_NOTIFY_CHANGE_LAST_WRITE = 16;
+pub const FILE_NOTIFY_CHANGE_DIR_NAME = 2;
+pub const FILE_NOTIFY_CHANGE_FILE_NAME = 1;
+pub const FILE_NOTIFY_CHANGE_ATTRIBUTES = 4;
std/unicode.zig
@@ -188,6 +188,7 @@ pub const Utf8View = struct {
         return Utf8View{ .bytes = s };
     }
 
+    /// TODO: https://github.com/ziglang/zig/issues/425
     pub fn initComptime(comptime s: []const u8) Utf8View {
         if (comptime init(s)) |r| {
             return r;
@@ -199,7 +200,7 @@ pub const Utf8View = struct {
         }
     }
 
-    pub fn iterator(s: *const Utf8View) Utf8Iterator {
+    pub fn iterator(s: Utf8View) Utf8Iterator {
         return Utf8Iterator{
             .bytes = s.bytes,
             .i = 0,
@@ -530,3 +531,20 @@ test "utf16leToUtf8" {
         assert(mem.eql(u8, utf8, "\xf4\x8f\xb0\x80"));
     }
 }
+
+/// TODO support codepoints bigger than 16 bits
+/// TODO type for null terminated pointer
+pub fn utf8ToUtf16LeWithNull(allocator: *mem.Allocator, utf8: []const u8) ![]u16 {
+    var result = std.ArrayList(u16).init(allocator);
+    // optimistically guess that it will not require surrogate pairs
+    try result.ensureCapacity(utf8.len + 1);
+
+    const view = try Utf8View.init(utf8);
+    var it = view.iterator();
+    while (it.nextCodepoint()) |codepoint| {
+        try result.append(@intCast(u16, codepoint)); // TODO surrogate pairs
+    }
+
+    try result.append(0);
+    return result.toOwnedSlice();
+}