Commit e06ab1b010

Luuk de Gram <luuk@degram.dev>
2023-06-23 19:14:55
std: implement `detach` for WASI-threads
When a thread is detached from the main thread, we automatically clean up any memory allocated for it. To do this, we first reset the stack pointer to the original stack pointer of the main thread, so that we can safely free the memory, which also contains the thread's stack.
1 parent 622b7c4
Changed files (2)
lib
lib/std/Thread/Futex.zig
@@ -453,7 +453,7 @@ const WasmImpl = struct {
         if (!comptime std.Target.wasm.featureSetHas(builtin.target.cpu.features, .atomics)) {
             @compileError("WASI target missing cpu feature 'atomics'");
         }
-        const to: i64 = if (timeout) |to| @intCast(i64, to) else -1;
+        const to: i64 = if (timeout) |to| @intCast(to) else -1;
         const result = asm (
             \\local.get %[ptr]
             \\local.get %[expected]
@@ -462,7 +462,7 @@ const WasmImpl = struct {
             \\local.set %[ret]
             : [ret] "=r" (-> u32),
             : [ptr] "r" (&ptr.value),
-              [expected] "r" (@bitCast(i32, expect)),
+              [expected] "r" (@as(i32, @bitCast(expect))),
               [timeout] "r" (to),
         );
         switch (result) {
lib/std/Thread.zig
@@ -757,6 +757,8 @@ const WasiThreadImpl = struct {
         /// The allocator used to allocate the thread's memory,
         /// which is also used during `join` to ensure clean-up.
         allocator: std.mem.Allocator,
+        /// The current state of the thread.
+        state: State = State.init(.running),
     };
 
     /// A meta-data structure used to bootstrap a thread
@@ -775,8 +777,15 @@ const WasiThreadImpl = struct {
         /// function upon thread spawn. The above mentioned pointer will be passed
         /// to this function pointer as its argument.
         call_back: *const fn (usize) void,
+        /// When a thread is in `detached` state, we must free all of its memory
+        /// upon thread completion. However, as this is done while still within
+        /// the thread, we must first jump back to the main thread's stack or else
+        /// we end up freeing the stack that we're currently using.
+        original_stack_pointer: [*]u8,
     };
 
+    const State = Atomic(enum(u8) { running, completed, detached });
+
     fn getCurrentId() Id {
         return tls_thread_id;
     }
@@ -786,7 +795,11 @@ const WasiThreadImpl = struct {
     }
 
     fn detach(self: Impl) void {
-        _ = self;
+        switch (self.thread.state.swap(.detached, .SeqCst)) {
+            .running => {},
+            .completed => self.join(),
+            .detached => unreachable,
+        }
     }
 
     fn join(self: Impl) void {
@@ -836,7 +849,7 @@ const WasiThreadImpl = struct {
         const Wrapper = struct {
             args: @TypeOf(args),
             fn entry(ptr: usize) void {
-                const w = @intToPtr(*@This(), ptr);
+                const w: *@This() = @ptrFromInt(ptr);
                 @call(.auto, f, w.args);
             }
         };
@@ -854,7 +867,7 @@ const WasiThreadImpl = struct {
             // start with atleast a single page, which is used as a guard to prevent
             // other threads clobbering our new thread.
             // Unfortunately, WebAssembly has no notion of read-only segments, so this
-            // is only a temporary measure until the entire page is "run over".
+            // is only a best effort.
             var bytes: usize = std.wasm.page_size;
 
             bytes = std.mem.alignForward(usize, bytes, 16); // align stack to 16 bytes
@@ -880,16 +893,17 @@ const WasiThreadImpl = struct {
         // Allocate the amount of memory required for all meta data.
         const allocated_memory = try config.allocator.?.alloc(u8, map_bytes);
 
-        const wrapper = @ptrCast(*Wrapper, @alignCast(@alignOf(Wrapper), &allocated_memory[wrapper_offset]));
+        const wrapper: *Wrapper = @ptrCast(@alignCast(&allocated_memory[wrapper_offset]));
         wrapper.* = .{ .args = args };
 
-        const instance = @ptrCast(*Instance, @alignCast(@alignOf(Instance), &allocated_memory[instance_offset]));
+        const instance: *Instance = @ptrCast(@alignCast(&allocated_memory[instance_offset]));
         instance.* = .{
             .thread = .{ .memory = allocated_memory, .allocator = config.allocator.? },
             .tls_offset = tls_offset,
             .stack_offset = stack_offset,
-            .raw_ptr = @ptrToInt(wrapper),
+            .raw_ptr = @intFromPtr(wrapper),
             .call_back = &Wrapper.entry,
+            .original_stack_pointer = __get_stack_pointer(),
         };
 
         const tid = spawnWasiThread(instance);
@@ -903,32 +917,46 @@ const WasiThreadImpl = struct {
         return .{ .thread = &instance.thread };
     }
 
-    /// Bootstrap procedure, called by the HOST environment after thread creation.
+    /// Bootstrap procedure, called by the host environment after thread creation.
     export fn wasi_thread_start(tid: i32, arg: *Instance) void {
         __set_stack_pointer(arg.thread.memory.ptr + arg.stack_offset);
         __wasm_init_tls(arg.thread.memory.ptr + arg.tls_offset);
-        WasiThreadImpl.tls_thread_id = @intCast(u32, tid);
+        @atomicStore(u32, &WasiThreadImpl.tls_thread_id, @intCast(tid), .SeqCst);
 
         // Finished bootstrapping, call user's procedure.
         arg.call_back(arg.raw_ptr);
 
-        // Thread finished. Reset Thread ID and wake up the main thread if needed.
-        // We use inline assembly here as we must ensure not to use the stack.
-        asm volatile (
-            \\ local.get %[ptr]
-            \\ i32.const 0
-            \\ i32.atomic.store 0
-            :
-            : [ptr] "r" (&arg.thread.tid.value),
-        );
-        asm volatile (
-            \\ local.get %[ptr]
-            \\ i32.const 1 # waiters
-            \\ memory.atomic.notify 0
-            \\ drop # no need to know the waiters
-            :
-            : [ptr] "r" (&arg.thread.tid.value),
-        );
+        switch (arg.thread.state.swap(.completed, .SeqCst)) {
+            .running => {
+                // reset the Thread ID
+                asm volatile (
+                    \\ local.get %[ptr]
+                    \\ i32.const 0
+                    \\ i32.atomic.store 0
+                    :
+                    : [ptr] "r" (&arg.thread.tid.value),
+                );
+
+                // Wake the main thread listening to this thread
+                asm volatile (
+                    \\ local.get %[ptr]
+                    \\ i32.const 1 # waiters
+                    \\ memory.atomic.notify 0
+                    \\ drop # no need to know the waiters
+                    :
+                    : [ptr] "r" (&arg.thread.tid.value),
+                );
+            },
+            .completed => unreachable,
+            .detached => {
+                // restore the original stack pointer so we can free the memory
+                // without having to worry about freeing the stack
+                __set_stack_pointer(arg.original_stack_pointer);
+                // Ensure a copy so we don't free the allocator reference itself
+                var allocator = arg.thread.allocator;
+                allocator.free(arg.thread.memory);
+            },
+        }
     }
 
     /// Asks the host to create a new thread for us.
@@ -980,6 +1008,15 @@ const WasiThreadImpl = struct {
             : [ptr] "r" (addr),
         );
     }
+
+    /// Returns the current value of the stack pointer
+    inline fn __get_stack_pointer() [*]u8 {
+        return asm (
+            \\ global.get __stack_pointer
+            \\ local.set %[stack_ptr]
+            : [stack_ptr] "=r" (-> [*]u8),
+        );
+    }
 };
 
 const LinuxThreadImpl = struct {