Commit 37e2a04da8

Luuk de Gram <luuk@degram.dev>
2023-07-06 19:31:08
add stand alone test to verify bulk-memory features
This adds a standalone test case to ensure the runtime does not trap when performing a memory.copy or memory.fill instruction while the destination or source address is out-of-bounds and the length is 0.
1 parent d54ebf4
Changed files (4)
src
codegen
test
standalone
zerolength_check
src/codegen/llvm.zig
@@ -8514,8 +8514,8 @@ pub const FuncGen = struct {
         // Any WebAssembly runtime will trap when the destination pointer is out-of-bounds, regardless
         // of the length. This means we need to emit a check where we skip the memset when the length
         // is 0 as we allow for undefined pointers in 0-sized slices.
-        const needs_wasm_safety_check = safety and
-            o.target.isWasm() and
+        // This logic can be removed once https://github.com/ziglang/zig/issues/16360 is done.
+        const intrinsic_len0_traps = o.target.isWasm() and
             ptr_ty.isSlice(mod) and
             std.Target.wasm.featureSetHas(o.target.cpu.features, .bulk_memory);
 
@@ -8529,7 +8529,7 @@ pub const FuncGen = struct {
                 else
                     u8_llvm_ty.getUndef();
                 const len = self.sliceOrArrayLenInBytes(dest_slice, ptr_ty);
-                if (needs_wasm_safety_check) {
+                if (intrinsic_len0_traps) {
                     try self.safeWasmMemset(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
                 } else {
                     _ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
@@ -8552,7 +8552,7 @@ pub const FuncGen = struct {
                 });
                 const len = self.sliceOrArrayLenInBytes(dest_slice, ptr_ty);
 
-                if (needs_wasm_safety_check) {
+                if (intrinsic_len0_traps) {
                     try self.safeWasmMemset(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
                 } else {
                     _ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
@@ -8569,7 +8569,7 @@ pub const FuncGen = struct {
             const fill_byte = try self.bitCast(value, elem_ty, Type.u8);
             const len = self.sliceOrArrayLenInBytes(dest_slice, ptr_ty);
 
-            if (needs_wasm_safety_check) {
+            if (intrinsic_len0_traps) {
                 try self.safeWasmMemset(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
             } else {
                 _ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
@@ -8652,19 +8652,15 @@ pub const FuncGen = struct {
         dest_ptr_align: u32,
         is_volatile: bool,
     ) !void {
-        const parent_block = self.context.createBasicBlock("Block");
         const llvm_usize_ty = self.context.intType(self.dg.object.target.ptrBitWidth());
-        const cond = try self.cmp(len, llvm_usize_ty.constInt(0, .False), Type.usize, .eq);
-        const then_block = self.context.appendBasicBlock(self.llvm_func, "Then");
-        const else_block = self.context.appendBasicBlock(self.llvm_func, "Else");
-        _ = self.builder.buildCondBr(cond, then_block, else_block);
-        self.builder.positionBuilderAtEnd(then_block);
-        _ = self.builder.buildBr(parent_block);
-        self.builder.positionBuilderAtEnd(else_block);
+        const cond = try self.cmp(len, llvm_usize_ty.constInt(0, .False), Type.usize, .neq);
+        const memset_block = self.context.appendBasicBlock(self.llvm_func, "MemsetTrapSkip");
+        const end_block = self.context.appendBasicBlock(self.llvm_func, "MemsetTrapEnd");
+        _ = self.builder.buildCondBr(cond, memset_block, end_block);
+        self.builder.positionBuilderAtEnd(memset_block);
         _ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
-        _ = self.builder.buildBr(parent_block);
-        self.llvm_func.appendExistingBasicBlock(parent_block);
-        self.builder.positionBuilderAtEnd(parent_block);
+        _ = self.builder.buildBr(end_block);
+        self.builder.positionBuilderAtEnd(end_block);
     }
 
     fn airMemcpy(self: *FuncGen, inst: Air.Inst.Index) !?*llvm.Value {
@@ -8682,24 +8678,19 @@ pub const FuncGen = struct {
 
         // When bulk-memory is enabled, this will be lowered to WebAssembly's memory.copy instruction.
         // This instruction will trap on an invalid address, regardless of the length.
-        // For this reason we must add a safety-check for 0-sized slices as its pointer field can be undefined.
+        // For this reason we must add a check for 0-sized slices as its pointer field can be undefined.
         // We only have to do this for slices as arrays will have a valid pointer.
+        // This logic can be removed once https://github.com/ziglang/zig/issues/16360 is done.
         if (o.target.isWasm() and
             std.Target.wasm.featureSetHas(o.target.cpu.features, .bulk_memory) and
-            (src_ptr_ty.isSlice(mod) or dest_ptr_ty.isSlice(mod)))
+            dest_ptr_ty.isSlice(mod))
         {
-            const parent_block = self.context.createBasicBlock("Block");
-
-            const llvm_usize_ty = self.context.intType(o.target.ptrBitWidth());
-            const cond = try self.cmp(len, llvm_usize_ty.constInt(0, .False), Type.usize, .eq);
-            const then_block = self.context.appendBasicBlock(self.llvm_func, "Then");
-            const else_block = self.context.appendBasicBlock(self.llvm_func, "Else");
-            _ = self.builder.buildCondBr(cond, then_block, else_block);
-
-            self.builder.positionBuilderAtEnd(then_block);
-            _ = self.builder.buildBr(parent_block);
-
-            self.builder.positionBuilderAtEnd(else_block);
+            const llvm_usize_ty = self.context.intType(self.dg.object.target.ptrBitWidth());
+            const cond = try self.cmp(len, llvm_usize_ty.constInt(0, .False), Type.usize, .neq);
+            const memcpy_block = self.context.appendBasicBlock(self.llvm_func, "MemcpyTrapSkip");
+            const end_block = self.context.appendBasicBlock(self.llvm_func, "MemcpyTrapEnd");
+            _ = self.builder.buildCondBr(cond, memcpy_block, end_block);
+            self.builder.positionBuilderAtEnd(memcpy_block);
             _ = self.builder.buildMemCpy(
                 dest_ptr,
                 dest_ptr_ty.ptrAlignment(mod),
@@ -8708,9 +8699,8 @@ pub const FuncGen = struct {
                 len,
                 is_volatile,
             );
-            _ = self.builder.buildBr(parent_block);
-            self.llvm_func.appendExistingBasicBlock(parent_block);
-            self.builder.positionBuilderAtEnd(parent_block);
+            _ = self.builder.buildBr(end_block);
+            self.builder.positionBuilderAtEnd(end_block);
             return null;
         }
 
test/standalone/zerolength_check/src/main.zig
@@ -0,0 +1,23 @@
+const std = @import("std");
+
+test {
+    var dest = foo();
+    var source = foo();
+
+    @memcpy(dest, source);
+    @memset(dest, 4);
+    @memset(dest, undefined);
+
+    var dest2 = foo2();
+    @memset(dest2, 0);
+}
+
+fn foo() []u8 {
+    const ptr = comptime std.mem.alignBackward(usize, std.math.maxInt(usize), 1);
+    return @as([*]align(1) u8, @ptrFromInt(ptr))[0..0];
+}
+
+fn foo2() []u64 {
+    const ptr = comptime std.mem.alignBackward(usize, std.math.maxInt(usize), 1);
+    return @as([*]align(1) u64, @ptrFromInt(ptr))[0..0];
+}
test/standalone/zerolength_check/build.zig
@@ -0,0 +1,27 @@
+const std = @import("std");
+
+pub fn build(b: *std.Build) void {
+    const test_step = b.step("test", "Test it");
+    b.default_step = test_step;
+
+    add(b, test_step, .Debug);
+    add(b, test_step, .ReleaseFast);
+    add(b, test_step, .ReleaseSmall);
+    add(b, test_step, .ReleaseSafe);
+}
+
+fn add(b: *std.Build, test_step: *std.Build.Step, optimize: std.builtin.OptimizeMode) void {
+    const unit_tests = b.addTest(.{
+        .root_source_file = .{ .path = "src/main.zig" },
+        .target = .{
+            .os_tag = .wasi,
+            .cpu_arch = .wasm32,
+            .cpu_features_add = std.Target.wasm.featureSet(&.{.bulk_memory}),
+        },
+        .optimize = optimize,
+    });
+
+    const run_unit_tests = b.addRunArtifact(unit_tests);
+    run_unit_tests.skip_foreign_checks = true;
+    test_step.dependOn(&run_unit_tests.step);
+}
test/standalone.zig
@@ -230,6 +230,10 @@ pub const build_cases = [_]BuildCase{
         .build_root = "test/standalone/cmakedefine",
         .import = @import("standalone/cmakedefine/build.zig"),
     },
+    .{
+        .build_root = "test/standalone/zerolength_check",
+        .import = @import("standalone/zerolength_check/build.zig"),
+    },
 };
 
 const std = @import("std");