Commit 6376d96824

Andrew Kelley <superjoe30@gmail.com>
2018-04-29 08:40:22
support kernel threads for windows
* remove std.os.spawnThreadAllocator - windows does not support an explicit stack, so using an allocator for a thread stack space does not work. * std.os.spawnThread - instead of accepting a stack argument, the implementation will directly allocate using OS-specific APIs.
1 parent bf8e419
Changed files (6)
std/atomic/queue.zig
@@ -53,10 +53,6 @@ const puts_per_thread = 10000;
 const put_thread_count = 3;
 
 test "std.atomic.queue" {
-    if (builtin.os == builtin.Os.windows) {
-        // TODO implement kernel threads for windows
-        return;
-    }
     var direct_allocator = std.heap.DirectAllocator.init();
     defer direct_allocator.deinit();
 
@@ -79,11 +75,11 @@ test "std.atomic.queue" {
 
     var putters: [put_thread_count]&std.os.Thread = undefined;
     for (putters) |*t| {
-        *t = try std.os.spawnThreadAllocator(a, &context, startPuts);
+        *t = try std.os.spawnThread(&context, startPuts);
     }
     var getters: [put_thread_count]&std.os.Thread = undefined;
     for (getters) |*t| {
-        *t = try std.os.spawnThreadAllocator(a, &context, startGets);
+        *t = try std.os.spawnThread(&context, startGets);
     }
 
     for (putters) |t| t.wait();
std/atomic/stack.zig
@@ -60,10 +60,6 @@ const puts_per_thread = 1000;
 const put_thread_count = 3;
 
 test "std.atomic.stack" {
-    if (builtin.os == builtin.Os.windows) {
-        // TODO implement kernel threads for windows
-        return;
-    }
     var direct_allocator = std.heap.DirectAllocator.init();
     defer direct_allocator.deinit();
 
@@ -85,11 +81,11 @@ test "std.atomic.stack" {
 
     var putters: [put_thread_count]&std.os.Thread = undefined;
     for (putters) |*t| {
-        *t = try std.os.spawnThreadAllocator(a, &context, startPuts);
+        *t = try std.os.spawnThread(&context, startPuts);
     }
     var getters: [put_thread_count]&std.os.Thread = undefined;
     for (getters) |*t| {
-        *t = try std.os.spawnThreadAllocator(a, &context, startGets);
+        *t = try std.os.spawnThread(&context, startGets);
     }
 
     for (putters) |t| t.wait();
std/os/windows/index.zig
@@ -28,6 +28,9 @@ pub extern "kernel32" stdcallcc fn CreateProcessA(lpApplicationName: ?LPCSTR, lp
 pub extern "kernel32" stdcallcc fn CreateSymbolicLinkA(lpSymlinkFileName: LPCSTR, lpTargetFileName: LPCSTR,
     dwFlags: DWORD) BOOLEAN;
 
+
+pub extern "kernel32" stdcallcc fn CreateThread(lpThreadAttributes: ?LPSECURITY_ATTRIBUTES, dwStackSize: SIZE_T, lpStartAddress: LPTHREAD_START_ROUTINE, lpParameter: ?LPVOID, dwCreationFlags: DWORD, lpThreadId: ?LPDWORD) ?HANDLE;
+
 pub extern "kernel32" stdcallcc fn DeleteFileA(lpFileName: LPCSTR) BOOL;
 
 pub extern "kernel32" stdcallcc fn ExitProcess(exit_code: UINT) noreturn;
@@ -318,6 +321,9 @@ pub const HEAP_CREATE_ENABLE_EXECUTE = 0x00040000;
 pub const HEAP_GENERATE_EXCEPTIONS = 0x00000004;
 pub const HEAP_NO_SERIALIZE = 0x00000001;
 
+pub const PTHREAD_START_ROUTINE = extern fn(LPVOID) DWORD;
+pub const LPTHREAD_START_ROUTINE = PTHREAD_START_ROUTINE;
+
 test "import" {
     _ = @import("util.zig");
 }
std/os/index.zig
@@ -2347,18 +2347,30 @@ pub fn posixGetSockOptConnectError(sockfd: i32) PosixConnectError!void {
 }
 
 pub const Thread = struct {
-    pid: pid_t,
-    allocator: ?&mem.Allocator,
-    stack: []u8,
-    pthread_handle: pthread_t,
+    data: Data,
 
     pub const use_pthreads = is_posix and builtin.link_libc;
-    const pthread_t = if (use_pthreads) c.pthread_t else void;
-    const pid_t = if (!use_pthreads) i32 else void;
+    const Data = if (use_pthreads) struct {
+      handle: c.pthread_t,
+      stack_addr: usize,
+      stack_len: usize,
+    } else switch (builtin.os) {
+        builtin.Os.linux => struct {
+            pid: i32,
+            stack_addr: usize,
+            stack_len: usize,
+        },
+        builtin.Os.windows => struct {
+            handle: windows.HANDLE,
+            alloc_start: &c_void,
+            heap_handle: windows.HANDLE,
+        },
+        else => @compileError("Unsupported OS"),
+    };
 
     pub fn wait(self: &const Thread) void {
         if (use_pthreads) {
-            const err = c.pthread_join(self.pthread_handle, null);
+            const err = c.pthread_join(self.data.handle, null);
             switch (err) {
                 0 => {},
                 posix.EINVAL => unreachable,
@@ -2366,23 +2378,27 @@ pub const Thread = struct {
                 posix.EDEADLK => unreachable,
                 else => unreachable,
             }
-        } else if (builtin.os == builtin.Os.linux) {
-            while (true) {
-                const pid_value = @atomicLoad(i32, &self.pid, builtin.AtomicOrder.SeqCst);
-                if (pid_value == 0) break;
-                const rc = linux.futex_wait(@ptrToInt(&self.pid), linux.FUTEX_WAIT, pid_value, null);
-                switch (linux.getErrno(rc)) {
-                    0 => continue,
-                    posix.EINTR => continue,
-                    posix.EAGAIN => continue,
-                    else => unreachable,
+            assert(posix.munmap(self.data.stack_addr, self.data.stack_len) == 0);
+        } else switch (builtin.os) {
+            builtin.Os.linux => {
+                while (true) {
+                    const pid_value = @atomicLoad(i32, &self.data.pid, builtin.AtomicOrder.SeqCst);
+                    if (pid_value == 0) break;
+                    const rc = linux.futex_wait(@ptrToInt(&self.data.pid), linux.FUTEX_WAIT, pid_value, null);
+                    switch (linux.getErrno(rc)) {
+                        0 => continue,
+                        posix.EINTR => continue,
+                        posix.EAGAIN => continue,
+                        else => unreachable,
+                    }
                 }
-            }
-        } else {
-            @compileError("Unsupported OS");
-        }
-        if (self.allocator) |a| {
-            a.free(self.stack);
+                assert(posix.munmap(self.data.stack_addr, self.data.stack_len) == 0);
+            },
+            builtin.Os.windows => {
+                assert(windows.WaitForSingleObject(self.data.handle, windows.INFINITE) == windows.WAIT_OBJECT_0);
+                assert(windows.HeapFree(self.data.heap_handle, 0, self.data.alloc_start) != 0);
+            },
+            else => @compileError("Unsupported OS"),
         }
     }
 };
@@ -2407,52 +2423,60 @@ pub const SpawnThreadError = error {
     /// be copied.
     SystemResources,
 
-    /// pthreads requires at least 16384 bytes of stack space
-    StackTooSmall,
+    /// Not enough userland memory to spawn the thread.
+    OutOfMemory,
 
     Unexpected,
 };
 
-pub const SpawnThreadAllocatorError = SpawnThreadError || error{OutOfMemory};
-
 /// caller must call wait on the returned thread
 /// fn startFn(@typeOf(context)) T
 /// where T is u8, noreturn, void, or !void
-pub fn spawnThreadAllocator(allocator: &mem.Allocator, context: var, comptime startFn: var) SpawnThreadAllocatorError!&Thread {
+/// caller must call wait on the returned thread
+pub fn spawnThread(context: var, comptime startFn: var) SpawnThreadError!&Thread {
     // TODO compile-time call graph analysis to determine stack upper bound
     // https://github.com/zig-lang/zig/issues/157
     const default_stack_size = 8 * 1024 * 1024;
-    const stack_bytes = try allocator.alignedAlloc(u8, os.page_size, default_stack_size);
-    const thread = try spawnThread(stack_bytes, context, startFn);
-    thread.allocator = allocator;
-    return thread;
-}
 
-/// stack must be big enough to store one Thread and one @typeOf(context), each with default alignment, at the end
-/// fn startFn(@typeOf(context)) T
-/// where T is u8, noreturn, void, or !void
-/// caller must call wait on the returned thread
-pub fn spawnThread(stack: []align(os.page_size) u8, context: var, comptime startFn: var) SpawnThreadError!&Thread {
     const Context = @typeOf(context);
     comptime assert(@ArgType(@typeOf(startFn), 0) == Context);
 
-    var stack_end: usize = @ptrToInt(stack.ptr) + stack.len;
-    var arg: usize = undefined;
-    if (@sizeOf(Context) != 0) {
-        stack_end -= @sizeOf(Context);
-        stack_end -= stack_end % @alignOf(Context);
-        assert(stack_end >= @ptrToInt(stack.ptr));
-        const context_ptr = @alignCast(@alignOf(Context), @intToPtr(&Context, stack_end));
-        *context_ptr = context;
-        arg = stack_end;
-    }
+    if (builtin.os == builtin.Os.windows) {
+        const WinThread = struct {
+            const OuterContext = struct {
+                thread: Thread,
+                inner: Context,
+            };
+            extern fn threadMain(arg: windows.LPVOID) windows.DWORD {
+                if (@sizeOf(Context) == 0) {
+                    return startFn({});
+                } else {
+                    return startFn(*@ptrCast(&Context, @alignCast(@alignOf(Context), arg)));
+                }
+            }
+        };
 
-    stack_end -= @sizeOf(Thread);
-    stack_end -= stack_end % @alignOf(Thread);
-    assert(stack_end >= @ptrToInt(stack.ptr));
-    const thread_ptr = @alignCast(@alignOf(Thread), @intToPtr(&Thread, stack_end));
-    thread_ptr.stack = stack;
-    thread_ptr.allocator = null;
+        const heap_handle = windows.GetProcessHeap() ?? return SpawnThreadError.OutOfMemory;
+        const byte_count = @alignOf(WinThread.OuterContext) + @sizeOf(WinThread.OuterContext);
+        const bytes_ptr = windows.HeapAlloc(heap_handle, 0, byte_count) ?? return SpawnThreadError.OutOfMemory;
+        errdefer assert(windows.HeapFree(heap_handle, 0, bytes_ptr) != 0);
+        const bytes = @ptrCast(&u8, bytes_ptr)[0..byte_count];
+        const outer_context = std.heap.FixedBufferAllocator.init(bytes).allocator.create(WinThread.OuterContext) catch unreachable;
+        outer_context.inner = context;
+        outer_context.thread.data.heap_handle = heap_handle;
+        outer_context.thread.data.alloc_start = bytes_ptr;
+
+        const parameter = if (@sizeOf(Context) == 0) null else @ptrCast(&c_void, &outer_context.inner);
+        outer_context.thread.data.handle = windows.CreateThread(null, default_stack_size, WinThread.threadMain,
+            parameter, 0, null) ??
+        {
+            const err = windows.GetLastError();
+            return switch (err) {
+                else => os.unexpectedErrorWindows(err),
+            };
+        };
+        return &outer_context.thread;
+    }
 
     const MainFuncs = struct {
         extern fn linuxThreadMain(ctx_addr: usize) u8 {
@@ -2473,6 +2497,29 @@ pub fn spawnThread(stack: []align(os.page_size) u8, context: var, comptime start
         }
     };
 
+    const stack_len = default_stack_size;
+    const stack_addr = posix.mmap(null, stack_len, posix.PROT_READ|posix.PROT_WRITE, 
+            posix.MAP_PRIVATE|posix.MAP_ANONYMOUS|posix.MAP_GROWSDOWN, -1, 0);
+    if (stack_addr == posix.MAP_FAILED) return error.OutOfMemory;
+    errdefer _ = posix.munmap(stack_addr, stack_len);
+
+    var stack_end: usize = stack_addr + stack_len;
+    var arg: usize = undefined;
+    if (@sizeOf(Context) != 0) {
+        stack_end -= @sizeOf(Context);
+        stack_end -= stack_end % @alignOf(Context);
+        assert(stack_end >= stack_addr);
+        const context_ptr = @alignCast(@alignOf(Context), @intToPtr(&Context, stack_end));
+        *context_ptr = context;
+        arg = stack_end;
+    }
+
+    stack_end -= @sizeOf(Thread);
+    stack_end -= stack_end % @alignOf(Thread);
+    assert(stack_end >= stack_addr);
+    const thread_ptr = @alignCast(@alignOf(Thread), @intToPtr(&Thread, stack_end));
+
+
     if (builtin.os == builtin.Os.windows) {
         // use windows API directly
         @compileError("TODO support spawnThread for Windows");
@@ -2484,14 +2531,12 @@ pub fn spawnThread(stack: []align(os.page_size) u8, context: var, comptime start
 
         // align to page
         stack_end -= stack_end % os.page_size;
+        assert(c.pthread_attr_setstack(&attr, @intToPtr(&c_void, stack_addr), stack_len) == 0);
 
-        const stack_size = stack_end - @ptrToInt(stack.ptr);
-        const setstack_err = c.pthread_attr_setstack(&attr, @ptrCast(&c_void, stack.ptr), stack_size);
-        if (setstack_err != 0) {
-            return SpawnThreadError.StackTooSmall; // pthreads requires at least 16384 bytes
-        }
+        thread_ptr.data.stack_addr = stack_addr;
+        thread_ptr.data.stack_len = stack_len;
 
-        const err = c.pthread_create(&thread_ptr.pthread_handle, &attr, MainFuncs.posixThreadMain, @intToPtr(&c_void, arg));
+        const err = c.pthread_create(&thread_ptr.data.handle, &attr, MainFuncs.posixThreadMain, @intToPtr(&c_void, arg));
         switch (err) {
             0 => return thread_ptr,
             posix.EAGAIN => return SpawnThreadError.SystemResources,
std/os/test.zig
@@ -44,24 +44,12 @@ test "access file" {
 }
 
 test "spawn threads" {
-    if (builtin.os == builtin.Os.windows) {
-        // TODO implement threads on windows
-        return;
-    }
-
-    var direct_allocator = std.heap.DirectAllocator.init();
-    defer direct_allocator.deinit();
-
     var shared_ctx: i32 = 1;
 
-    const thread1 = try std.os.spawnThreadAllocator(&direct_allocator.allocator, {}, start1);
-    const thread4 = try std.os.spawnThreadAllocator(&direct_allocator.allocator, &shared_ctx, start2);
-
-    var stack1: [20 * 1024]u8 align(os.page_size) = undefined;
-    var stack2: [20 * 1024]u8 align(os.page_size) = undefined;
-
-    const thread2 = try std.os.spawnThread(stack1[0..], &shared_ctx, start2);
-    const thread3 = try std.os.spawnThread(stack2[0..], &shared_ctx, start2);
+    const thread1 = try std.os.spawnThread({}, start1);
+    const thread2 = try std.os.spawnThread(&shared_ctx, start2);
+    const thread3 = try std.os.spawnThread(&shared_ctx, start2);
+    const thread4 = try std.os.spawnThread(&shared_ctx, start2);
 
     thread1.wait();
     thread2.wait();
std/mem.zig
@@ -32,6 +32,7 @@ pub const Allocator = struct {
     freeFn: fn (self: &Allocator, old_mem: []u8) void,
 
     fn create(self: &Allocator, comptime T: type) !&T {
+        if (@sizeOf(T) == 0) return &{};
         const slice = try self.alloc(T, 1);
         return &slice[0];
     }