Commit a97dbdfa0b

Luuk de Gram <luuk@degram.dev>
2023-06-20 22:19:51
std: implement `Thread` `spawn` for WASI
This implements a first version to spawn a WASI-thread. For a new thread to be created, we calculate the size required to store TLS, the new stack, and metadata. This size is then allocated using a user-provided allocator. After a new thread is spawn, the HOST will call into our bootstrap procedure. This bootstrap procedure will then initialize the TLS segment and set the newly spawned thread's TID. It will also set the stack pointer to the newly created stack to ensure we do not clobber the main thread's stack. When bootstrapping the thread is completed, we will call the user's function on this new thread.
1 parent ea0d4c8
Changed files (1)
lib
lib/std/Thread.zig
@@ -28,6 +28,8 @@ else if (use_pthreads)
     PosixThreadImpl
 else if (target.os.tag == .linux)
     LinuxThreadImpl
+else if (target.os.tag == .wasi)
+    WasiThreadImpl
 else
     UnsupportedImpl;
 
@@ -266,6 +268,7 @@ pub const Id = switch (target.os.tag) {
     .freebsd,
     .openbsd,
     .haiku,
+    .wasi,
     => u32,
     .macos, .ios, .watchos, .tvos => u64,
     .windows => os.windows.DWORD,
@@ -296,6 +299,8 @@ pub const SpawnConfig = struct {
 
     /// Size in bytes of the Thread's stack
     stack_size: usize = 16 * 1024 * 1024,
+    /// The allocator to be used to allocate memory for the to-be-spawned thread
+    allocator: ?std.mem.Allocator = null,
 };
 
 pub const SpawnError = error{
@@ -733,6 +738,193 @@ const PosixThreadImpl = struct {
     }
 };
 
+const WasiThreadImpl = struct {
+    comptime {
+        // Sets the stack pointer, which is needed after creating a new thread
+        // to ensure the stack of the main thread isn't being poluted.
+        asm (
+            \\ .text
+            \\ .export_name	__set_stack_pointer, __set_stack_pointer
+            \\ .globaltype __stack_pointer, i32
+            \\ .hidden wasi_thread_start
+            \\ .globl wasi_thread_start
+            \\ .type __set_stack_pointer, @function
+            \\
+            \\ __set_stack_pointer:
+            \\	  .functype	__set_stack_pointer (i32) -> ()
+            \\    local.get 0 # The raw pointer which replaces the stack pointer
+            \\    global.set __stack_pointer
+            \\    end_function
+        );
+    }
+    thread: *WasiThread,
+
+    pub const ThreadHandle = i32;
+    threadlocal var tls_thread_id: Id = 0;
+
+    const WasiThread = struct {
+        tid: Atomic(i32) = Atomic(i32).init(0),
+        memory: []u8,
+    };
+
+    /// A meta-data structure used to bootstrap a thread
+    const Instance = struct {
+        thread: WasiThread,
+        /// Address of this `Instance`
+        base: usize,
+        /// Contains the pointer of the new __tls_base.
+        tls_base: usize,
+        /// Contains the pointer to the stack for the newly spawned thread.
+        stack_pointer: usize,
+        /// Contains the pointer to the wrapper which holds all arguments
+        /// for the callback.
+        raw_ptr: usize,
+        /// Function pointer to a wrapping function which will call the user's
+        /// function upon thread spawn. The above mentioned pointer will be passed
+        /// to this function pointer as its argument.
+        call_back: *const fn (usize) void,
+    };
+
+    fn getCurrentId() Id {
+        return tls_thread_id;
+    }
+
+    fn getHandle(self: Impl) ThreadHandle {
+        return self.thread.tid;
+    }
+
+    fn detach(self: Impl) void {
+        _ = self;
+    }
+
+    fn join(self: Impl) void {
+        _ = self;
+    }
+
+    fn spawn(config: std.Thread.SpawnConfig, comptime f: anytype, args: anytype) !WasiThreadImpl {
+        if (config.allocator == null) return error.OutOfMemory; // an allocator is required to spawn a WASI-thread
+
+        // Wrapping struct required to hold the user-provided function arguments.
+        const Wrapper = struct {
+            args: @TypeOf(args),
+            fn entry(ptr: usize) void {
+                const w = @intToPtr(*@This(), ptr);
+                @call(.auto, f, w.args);
+            }
+        };
+
+        var guard_offset: usize = undefined;
+        var stack_offset: usize = undefined;
+        var tls_offset: usize = undefined;
+        var wrapper_offset: usize = undefined;
+        var instance_offset: usize = undefined;
+
+        // Calculate the bytes we have to allocate to store all thread information, including:
+        // - The actual stack for the thread
+        // - The TLS segment
+        // - `Instance` - containing information about how to call the user's function.
+        const map_bytes = blk: {
+            var bytes: usize = std.wasm.page_size;
+            guard_offset = bytes;
+
+            bytes = std.mem.alignForward(usize, bytes, 16); // align stack to 16 bytes
+            stack_offset = bytes;
+            bytes += @max(std.wasm.page_size, config.stack_size);
+
+            bytes = std.mem.alignForward(usize, bytes, __tls_align());
+            tls_offset = bytes;
+            bytes += __tls_size();
+
+            bytes = std.mem.alignForward(usize, bytes, @alignOf(Wrapper));
+            wrapper_offset = bytes;
+            bytes += @sizeOf(Wrapper);
+
+            bytes = std.mem.alignForward(usize, bytes, @alignOf(Instance));
+            instance_offset = bytes;
+            bytes += @sizeOf(Instance);
+
+            bytes = std.mem.alignForward(usize, bytes, std.wasm.page_size);
+            break :blk bytes;
+        };
+
+        // Allocate the amount of memory required for all meta data.
+        const allocated_memory = try config.allocator.?.alloc(u8, map_bytes);
+
+        const wrapper = @ptrCast(*Wrapper, @alignCast(@alignOf(Wrapper), &allocated_memory[wrapper_offset]));
+        wrapper.* = .{ .args = args };
+
+        const instance = @ptrCast(*Instance, @alignCast(@alignOf(Instance), &allocated_memory[instance_offset]));
+        instance.* = .{
+            .thread = .{ .memory = allocated_memory },
+            .base = @ptrToInt(allocated_memory.ptr),
+            .tls_base = tls_offset,
+            .stack_pointer = stack_offset,
+            .raw_ptr = @ptrToInt(wrapper),
+            .call_back = &Wrapper.entry,
+        };
+
+        const tid = spawnWasiThread(instance);
+        // The specification says any value lower than 0 indicates an error.
+        // The values of such error are unspecified. WASI-Libc treats it as EAGAIN.
+        if (tid < 0) {
+            return error.SystemResources;
+        }
+        instance.thread.tid.store(tid, .SeqCst);
+
+        return .{ .thread = &instance.thread };
+    }
+
+    export fn wasi_thread_start(tid: i32, arg: *const Instance) void {
+        __set_stack_pointer(arg.thread.memory.ptr + arg.stack_pointer);
+        __wasm_init_tls(arg.thread.memory.ptr + arg.tls_base);
+        WasiThreadImpl.tls_thread_id = @intCast(u32, tid);
+
+        // finished bootstrapping, call user's procedure.
+        arg.call_back(arg.raw_ptr);
+    }
+
+    // Asks the host to create a new thread for us.
+    // Newly created thread wil lcall `wasi_tread_start` with the thread ID as well
+    // as the input `arg` that was provided to `spawnWasiThread`
+    const spawnWasiThread = @"thread-spawn";
+    extern "wasi" fn @"thread-spawn"(arg: *const Instance) i32;
+
+    /// Initializes the TLS data segment starting at `memory`.
+    /// This is a synthetic function, generated by the linker.
+    extern fn __wasm_init_tls(memory: [*]u8) void;
+    extern fn __set_stack_pointer(ptr: [*]u8) void;
+
+    /// Returns a pointer to the base of the TLS data segment for the current thread
+    inline fn __tls_base() [*]u8 {
+        return asm (
+            \\ .globaltype __tls_base, i32
+            \\ global.get __tls_base
+            \\ local.set %[ret]
+            : [ret] "=r" (-> [*]u8),
+        );
+    }
+
+    /// Returns the size of the TLS segment
+    inline fn __tls_size() u32 {
+        return asm volatile (
+            \\ .globaltype __tls_size, i32, immutable
+            \\ global.get __tls_size
+            \\ local.set %[ret]
+            : [ret] "=r" (-> u32),
+        );
+    }
+
+    /// Returns the alignment of the TLS segment
+    inline fn __tls_align() u32 {
+        return asm (
+            \\ .globaltype __tls_align, i32, immutable
+            \\ global.get __tls_align
+            \\ local.set %[ret]
+            : [ret] "=r" (-> u32),
+        );
+    }
+};
+
 const LinuxThreadImpl = struct {
     const linux = os.linux;