Commit d54ebf4356

Luuk de Gram <luuk@degram.dev>
2023-07-04 20:12:48
llvm: add safety-check for Wasm memset
When lowering the `memset` instruction, LLVM will lower it to WebAssembly's `memory.fill` instruction when the bulk-memory feature is enabled. This instruction will trap when the destination address is out-of-bounds. By Zig's semantics, it is valid to have an invalid pointer when the length is 0. To prevent runtimes from trapping, we add a safety-check for slices to only lower to a memset instruction when the length is larger than 0.
1 parent 836f9fc
Changed files (1)
src
codegen
src/codegen/llvm.zig
@@ -8511,6 +8511,14 @@ pub const FuncGen = struct {
         const dest_ptr = self.sliceOrArrayPtr(dest_slice, ptr_ty);
         const is_volatile = ptr_ty.isVolatilePtr(mod);
 
+        // Any WebAssembly runtime will trap when the destination pointer is out-of-bounds, regardless
+        // of the length. This means we need to emit a check where we skip the memset when the length
+        // is 0 as we allow for undefined pointers in 0-sized slices.
+        const needs_wasm_safety_check = safety and
+            o.target.isWasm() and
+            ptr_ty.isSlice(mod) and
+            std.Target.wasm.featureSetHas(o.target.cpu.features, .bulk_memory);
+
         if (try self.air.value(bin_op.rhs, mod)) |elem_val| {
             if (elem_val.isUndefDeep(mod)) {
                 // Even if safety is disabled, we still emit a memset to undefined since it conveys
@@ -8521,7 +8529,11 @@ pub const FuncGen = struct {
                 else
                     u8_llvm_ty.getUndef();
                 const len = self.sliceOrArrayLenInBytes(dest_slice, ptr_ty);
-                _ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
+                if (needs_wasm_safety_check) {
+                    try self.safeWasmMemset(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
+                } else {
+                    _ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
+                }
 
                 if (safety and mod.comp.bin_file.options.valgrind) {
                     self.valgrindMarkUndef(dest_ptr, len);
@@ -8539,7 +8551,12 @@ pub const FuncGen = struct {
                     .val = byte_val,
                 });
                 const len = self.sliceOrArrayLenInBytes(dest_slice, ptr_ty);
-                _ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
+
+                if (needs_wasm_safety_check) {
+                    try self.safeWasmMemset(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
+                } else {
+                    _ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
+                }
                 return null;
             }
         }
@@ -8551,7 +8568,12 @@ pub const FuncGen = struct {
             // In this case we can take advantage of LLVM's intrinsic.
             const fill_byte = try self.bitCast(value, elem_ty, Type.u8);
             const len = self.sliceOrArrayLenInBytes(dest_slice, ptr_ty);
-            _ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
+
+            if (needs_wasm_safety_check) {
+                try self.safeWasmMemset(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
+            } else {
+                _ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
+            }
             return null;
         }
 
@@ -8622,6 +8644,29 @@ pub const FuncGen = struct {
         return null;
     }
 
+    /// Emits a `memset` of `len` bytes at `dest_ptr` that is skipped at runtime
+    /// when `len` is zero.
+    ///
+    /// With the bulk-memory feature enabled, LLVM lowers `memset` to WebAssembly's
+    /// `memory.fill`, which traps on an out-of-bounds destination regardless of
+    /// the length. Zig allows an undefined pointer in a 0-sized slice, so the
+    /// memset is guarded by a `len == 0` branch.
+    ///
+    /// On return the builder is positioned at the end of the join block that both
+    /// branches jump to. Propagates any error returned by `self.cmp`.
+    fn safeWasmMemset(
+        self: *FuncGen,
+        dest_ptr: *llvm.Value,
+        fill_byte: *llvm.Value,
+        len: *llvm.Value,
+        dest_ptr_align: u32,
+        is_volatile: bool,
+    ) !void {
+        // Build the comparison before creating any block: `cmp` is the only fallible
+        // step here, and a detached block that never gets appended to the function
+        // would be left orphaned if we returned early with an error.
+        const llvm_usize_ty = self.context.intType(self.dg.object.target.ptrBitWidth());
+        const len_is_zero = try self.cmp(len, llvm_usize_ty.constInt(0, .False), Type.usize, .eq);
+
+        const skip_block = self.context.appendBasicBlock(self.llvm_func, "Then");
+        const memset_block = self.context.appendBasicBlock(self.llvm_func, "Else");
+        // Created detached so that it can be appended after the two branch blocks.
+        const end_block = self.context.createBasicBlock("Block");
+        _ = self.builder.buildCondBr(len_is_zero, skip_block, memset_block);
+
+        // len == 0: nothing to write, jump straight to the join block.
+        self.builder.positionBuilderAtEnd(skip_block);
+        _ = self.builder.buildBr(end_block);
+
+        // len != 0: perform the actual memset, then join.
+        self.builder.positionBuilderAtEnd(memset_block);
+        _ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
+        _ = self.builder.buildBr(end_block);
+
+        // Attach the join block last and continue emitting code from there.
+        self.llvm_func.appendExistingBasicBlock(end_block);
+        self.builder.positionBuilderAtEnd(end_block);
+    }
+
     fn airMemcpy(self: *FuncGen, inst: Air.Inst.Index) !?*llvm.Value {
         const o = self.dg.object;
         const mod = o.module;