Commit 834609038c

Luuk de Gram <luuk@degram.dev>
2023-06-21 21:44:53
std: implement `join` for WASI-threads
We now reset the Thread ID to 0 and wake up the main thread listening for the thread to finish. We use inline assembly as we cannot use the stack to set the thread ID as it could possibly clobber any of the memory. Currently, we leak the memory that was allocated for the thread. We need to implement a way where we can clean up the memory without using the stack (as the stack is stored inside this same memory).
1 parent 10bf58b
Changed files (1)
lib
lib/std/Thread.zig
@@ -790,7 +790,40 @@ const WasiThreadImpl = struct {
     }
 
     fn join(self: Impl) void {
-        _ = self;
+        // TODO cleanup memory
+        // The memory also contains the thread's stack, which is problematic while freeing the memory
+        // defer self.thread.allocator.free(self.thread.memory);
+
+        var spin: u8 = 10;
+        while (true) {
+            const tid = self.thread.tid.load(.SeqCst);
+            if (tid == 0) {
+                break;
+            }
+
+            if (spin > 0) {
+                spin -= 1;
+                std.atomic.spinLoopHint();
+                continue;
+            }
+
+            const result = asm (
+                \\local.get %[ptr]
+                \\local.get %[expected]
+                \\i64.const -1 # infinite
+                \\memory.atomic.wait32 0
+                \\local.set %[ret]
+                : [ret] "=r" (-> u32),
+                : [ptr] "r" (&self.thread.tid.value),
+                  [expected] "r" (tid),
+            );
+            switch (result) {
+                0 => continue, // ok
+                1 => continue, // expected =! loaded
+                2 => unreachable, // timeout (infinite)
+                else => unreachable,
+            }
+        }
     }
 
     fn spawn(config: std.Thread.SpawnConfig, comptime f: anytype, args: anytype) !WasiThreadImpl {
@@ -868,25 +901,43 @@ const WasiThreadImpl = struct {
         return .{ .thread = &instance.thread };
     }
 
-    export fn wasi_thread_start(tid: i32, arg: *const Instance) void {
+    /// Bootstrap procedure, called by the HOST environment after thread creation.
+    export fn wasi_thread_start(tid: i32, arg: *Instance) void {
         __set_stack_pointer(arg.thread.memory.ptr + arg.stack_pointer);
         __wasm_init_tls(arg.thread.memory.ptr + arg.tls_base);
         WasiThreadImpl.tls_thread_id = @intCast(u32, tid);
 
-        // finished bootstrapping, call user's procedure.
+        // Finished bootstrapping, call user's procedure.
         arg.call_back(arg.raw_ptr);
+
+        // Thread finished. Reset Thread ID and wake up the main thread if needed.
+        // We use inline assembly here as we must ensure not to use the stack.
+        asm volatile (
+            \\ local.get %[ptr]
+            \\ i32.const 0
+            \\ i32.atomic.store 0
+            :
+            : [ptr] "r" (&arg.thread.tid.value),
+        );
+        asm volatile (
+            \\ local.get %[ptr]
+            \\ i32.const 1 # waiters
+            \\ memory.atomic.notify 0
+            \\ drop # no need to know the waiters
+            :
+            : [ptr] "r" (&arg.thread.tid.value),
+        );
     }
 
-    // Asks the host to create a new thread for us.
-    // Newly created thread wil lcall `wasi_tread_start` with the thread ID as well
-    // as the input `arg` that was provided to `spawnWasiThread`
+    /// Asks the host to create a new thread for us.
+    /// Newly created thread will call `wasi_tread_start` with the thread ID as well
+    /// as the input `arg` that was provided to `spawnWasiThread`
     const spawnWasiThread = @"thread-spawn";
-    extern "wasi" fn @"thread-spawn"(arg: *const Instance) i32;
+    extern "wasi" fn @"thread-spawn"(arg: *Instance) i32;
 
     /// Initializes the TLS data segment starting at `memory`.
     /// This is a synthetic function, generated by the linker.
     extern fn __wasm_init_tls(memory: [*]u8) void;
-    extern fn __set_stack_pointer(ptr: [*]u8) void;
 
     /// Returns a pointer to the base of the TLS data segment for the current thread
     inline fn __tls_base() [*]u8 {