Commit 4db01eb130

dweiller <4678790+dweiller@users.noreply.github.com>
2024-12-05 13:54:51
compiler-rt: optimize memmove
1 parent 7cef585
Changed files (3)
lib/compiler_rt/common.zig
@@ -16,6 +16,17 @@ else
 pub const visibility: std.builtin.SymbolVisibility =
     if (builtin.target.isWasm() and linkage != .internal) .hidden else .default;
 
+pub const PreferredLoadStoreElement = element: {
+    if (std.simd.suggestVectorLength(u8)) |vec_size| {
+        const Vec = @Vector(vec_size, u8);
+
+        if (@sizeOf(Vec) == vec_size and std.math.isPowerOfTwo(vec_size)) {
+            break :element Vec;
+        }
+    }
+    break :element usize;
+};
+
 pub const want_aeabi = switch (builtin.abi) {
     .eabi,
     .eabihf,
lib/compiler_rt/memcpy.zig
@@ -18,16 +18,7 @@ comptime {
     }
 }
 
-const Element = Element: {
-    if (std.simd.suggestVectorLength(u8)) |vec_size| {
-        const Vec = @Vector(vec_size, u8);
-
-        if (@sizeOf(Vec) == vec_size and std.math.isPowerOfTwo(vec_size)) {
-            break :Element Vec;
-        }
-    }
-    break :Element usize;
-};
+const Element = common.PreferredLoadStoreElement;
 
 comptime {
     assert(std.math.isPowerOfTwo(@sizeOf(Element)));
lib/compiler_rt/memmove.zig
@@ -1,6 +1,10 @@
 const std = @import("std");
 const common = @import("./common.zig");
 const builtin = @import("builtin");
+const assert = std.debug.assert;
+const memcpy = @import("memcpy.zig");
+
+const Element = common.PreferredLoadStoreElement;
 
 comptime {
     if (builtin.object_format != .c) {
@@ -34,137 +38,250 @@ fn memmoveSmall(opt_dest: ?[*]u8, opt_src: ?[*]const u8, len: usize) callconv(.C
     return dest;
 }
 
-pub fn memmoveFast(opt_dest: ?[*]u8, opt_src: ?[*]const u8, len: usize) callconv(.C) ?[*]u8 {
-    // a port of https://github.com/facebook/folly/blob/1c8bc50e88804e2a7361a57cd9b551dd10f6c5fd/folly/memcpy.S
-    if (len == 0) {
-        @branchHint(.unlikely);
-        return opt_dest;
+fn memmoveFast(dest: ?[*]u8, src: ?[*]u8, len: usize) callconv(.C) ?[*]u8 {
+    @setRuntimeSafety(builtin.is_test);
+    const unroll_count = 1;
+    comptime assert(std.math.isPowerOfTwo(unroll_count));
+    const small_limit = @max(2 * @sizeOf(Element), unroll_count * @sizeOf(Element));
+
+    if (copySmallLength(small_limit, dest.?, src.?, len)) return dest;
+
+    const dest_address = @intFromPtr(dest);
+    const src_address = @intFromPtr(src);
+
+    if (src_address < dest_address and src_address + len > dest_address) {
+        copyBackwards(unroll_count, dest.?, src.?, len);
+    } else {
+        copyForwards(unroll_count, dest.?, src.?, len);
     }
 
-    const dest = opt_dest.?;
-    const src = opt_src.?;
+    return dest;
+}
 
-    if (len < 8) {
-        @branchHint(.unlikely);
-        if (len == 1) {
-            @branchHint(.unlikely);
-            dest[0] = src[0];
-        } else if (len >= 4) {
-            @branchHint(.unlikely);
-            blockCopy(dest, src, 4, len);
-        } else {
-            blockCopy(dest, src, 2, len);
-        }
-        return dest;
+inline fn copySmallLength(
+    comptime small_limit: comptime_int,
+    dest: [*]u8,
+    src: [*]const u8,
+    len: usize,
+) bool {
+    if (len < 16) {
+        copyLessThan16(dest, src, len);
+        return true;
     }
 
-    if (len > 32) {
-        @branchHint(.unlikely);
-        if (len > 256) {
-            @branchHint(.unlikely);
-            copyMove(dest, src, len);
-            return dest;
-        }
-        copyLong(dest, src, len);
-        return dest;
+    if (comptime 2 < (std.math.log2(small_limit) + 1) / 2) {
+        if (copy16ToSmallLimit(small_limit, dest, src, len)) return true;
     }
 
-    if (len > 16) {
-        @branchHint(.unlikely);
-        blockCopy(dest, src, 16, len);
-        return dest;
+    return false;
+}
+
+inline fn copyLessThan16(
+    dest: [*]u8,
+    src: [*]const u8,
+    len: usize,
+) void {
+    @setRuntimeSafety(builtin.is_test);
+    if (len < 4) {
+        if (len == 0) return;
+        const b = len / 2;
+        const d0 = src[0];
+        const db = src[b];
+        const de = src[len - 1];
+        dest[0] = d0;
+        dest[b] = db;
+        dest[len - 1] = de;
+        return;
+    }
+    copyRange4(4, dest, src, len);
+}
+
+inline fn copy16ToSmallLimit(
+    comptime small_limit: comptime_int,
+    dest: [*]u8,
+    src: [*]const u8,
+    len: usize,
+) bool {
+    @setRuntimeSafety(builtin.is_test);
+    inline for (2..(std.math.log2(small_limit) + 1) / 2 + 1) |p| {
+        const limit = 1 << (2 * p);
+        if (len < limit) {
+            copyRange4(limit / 4, dest, src, len);
+            return true;
+        }
     }
+    return false;
+}
 
-    blockCopy(dest, src, 8, len);
+/// copy `len` bytes from `src` to `dest`; `len` must be in the range
+/// `[copy_len, 4 * copy_len)`.
+inline fn copyRange4(
+    comptime copy_len: comptime_int,
+    dest: [*]u8,
+    src: [*]const u8,
+    len: usize,
+) void {
+    @setRuntimeSafety(builtin.is_test);
+    comptime assert(std.math.isPowerOfTwo(copy_len));
+    assert(len >= copy_len);
+    assert(len < 4 * copy_len);
 
-    return dest;
+    const a = len & (copy_len * 2);
+    const b = a / 2;
+
+    const last = len - copy_len;
+    const pen = last - b;
+
+    const d0 = src[0..copy_len].*;
+    const d1 = src[b..][0..copy_len].*;
+    const d2 = src[pen..][0..copy_len].*;
+    const d3 = src[last..][0..copy_len].*;
+
+    dest[0..copy_len].* = d0;
+    dest[b..][0..copy_len].* = d1;
+    dest[pen..][0..copy_len].* = d2;
+    dest[last..][0..copy_len].* = d3;
+}
+
+inline fn copyForwards(
+    comptime unroll_count: comptime_int,
+    dest: [*]u8,
+    src: [*]const u8,
+    len: usize,
+) void {
+    @setRuntimeSafety(builtin.is_test);
+    assert(len >= 2 * @sizeOf(Element));
+    assert(len >= unroll_count * @sizeOf(Element));
+
+    const head = src[0..@sizeOf(Element)].*;
+    const tail = src[len - @sizeOf(Element) ..][0..@sizeOf(Element)].*;
+    const alignment_offset = @alignOf(Element) - @intFromPtr(src) % @alignOf(Element);
+    const n = len - alignment_offset;
+    const d = dest + alignment_offset;
+    const s = src + alignment_offset;
+
+    copyBlocksAlignedSource(@ptrCast(d), @alignCast(@ptrCast(s)), n, unroll_count);
+
+    // copy last `@sizeOf(Element)` bytes unconditionally, since block copy
+    // methods only copy a multiple of `@sizeOf(Element)` bytes.
+    dest[len - @sizeOf(Element) ..][0..@sizeOf(Element)].* = tail;
+    dest[0..@sizeOf(Element)].* = head;
 }
 
-inline fn blockCopy(dest: [*]u8, src: [*]const u8, block_size: comptime_int, len: usize) void {
-    const first = @as(*align(1) const @Vector(block_size, u8), src[0..block_size]).*;
-    const second = @as(*align(1) const @Vector(block_size, u8), src[len - block_size ..][0..block_size]).*;
-    dest[0..block_size].* = first;
-    dest[len - block_size ..][0..block_size].* = second;
+inline fn copyBlocksAlignedSource(
+    dest: [*]align(1) Element,
+    src: [*]const Element,
+    max_bytes: usize,
+    comptime unroll_count: comptime_int,
+) void {
+    copyBlocks(dest, src, max_bytes, unroll_count);
 }
 
-inline fn copyLong(dest: [*]u8, src: [*]const u8, len: usize) void {
-    var array: [8]@Vector(32, u8) = undefined;
+/// Copies the largest multiple of `@sizeOf(T)` bytes that does not exceed
+/// `max_bytes` from `src` to `dest`, where `T` is the child type of `src` and
+/// `dest`; `max_bytes` must be at least `@sizeOf(T)`. The primary copy loop
+/// will be unrolled to perform `unroll_count` copies per iteration.
+inline fn copyBlocks(
+    dest: anytype,
+    src: anytype,
+    max_bytes: usize,
+    comptime unroll_count: comptime_int,
+) void {
+    @setRuntimeSafety(builtin.is_test);
+    comptime assert(unroll_count > 0);
 
-    inline for (.{ 64, 128, 192, 256 }, 0..) |N, i| {
-        array[i * 2] = src[(N / 2) - 32 ..][0..32].*;
-        array[(i * 2) + 1] = src[len - N / 2 ..][0..32].*;
+    const T = @typeInfo(@TypeOf(dest)).pointer.child;
+    comptime assert(T == @typeInfo(@TypeOf(src)).pointer.child);
 
-        if (len <= N) {
-            @branchHint(.unlikely);
-            for (0..i + 1) |j| {
-                dest[j * 32 ..][0..32].* = array[j * 2];
-                dest[len - ((j * 32) + 32) ..][0..32].* = array[(j * 2) + 1];
-            }
-            return;
+    const loop_count = max_bytes / (@sizeOf(T) * unroll_count);
+
+    // save tail since it can overlap with `dest` in main copy loop
+    const tail_start = (max_bytes / @sizeOf(T)) - (unroll_count - 1);
+    const st = src[tail_start..][0 .. unroll_count - 1];
+    var tail_data: [unroll_count - 1]Element = undefined;
+    inline for (&tail_data, st) |*d, s| {
+        d.* = s;
+    }
+
+    for (0..loop_count) |i| {
+        const du = dest[i * unroll_count ..][0..unroll_count];
+        const su = src[i * unroll_count ..][0..unroll_count];
+        inline for (du, su) |*d, s| {
+            d.* = s;
         }
     }
-}
 
-inline fn copyMove(dest: [*]u8, src: [*]const u8, len: usize) void {
-    if (@intFromPtr(src) >= @intFromPtr(dest)) {
-        @branchHint(.unlikely);
-        copyForward(dest, src, len);
-    } else if (@intFromPtr(src) + len > @intFromPtr(dest)) {
-        @branchHint(.unlikely);
-        overlapBwd(dest, src, len);
-    } else {
-        copyForward(dest, src, len);
+    const dt = dest[tail_start..][0 .. unroll_count - 1];
+    inline for (dt, tail_data) |*d, s| {
+        d.* = s;
     }
 }
 
-inline fn copyForward(dest: [*]u8, src: [*]const u8, len: usize) void {
-    const tail: @Vector(32, u8) = src[len - 32 ..][0..32].*;
+inline fn copyBackwards(
+    comptime unroll_count: comptime_int,
+    dest: [*]u8,
+    src: [*]const u8,
+    len: usize,
+) void {
+    const end_bytes = src[len - @sizeOf(Element) ..][0..@sizeOf(Element)].*;
+    const start_bytes = src[0..@sizeOf(Element)].*;
+
+    const tail_dest: [*]Element = @ptrFromInt(std.mem.alignForward(usize, @intFromPtr(dest), @alignOf(Element)));
+    const tail_src: [*]align(1) const Element = @ptrCast(src + (@intFromPtr(tail_dest) - @intFromPtr(dest)));
+    const tail_bytes: [unroll_count - 1]Element = tail_src[0 .. unroll_count - 1].*;
 
-    const N: usize = len & ~@as(usize, 127);
-    var i: usize = 0;
+    const d_addr: usize = std.mem.alignBackward(usize, @intFromPtr(dest) + len, @alignOf(Element));
+    const d: [*]Element = @ptrFromInt(d_addr);
+    const n = d_addr - @intFromPtr(dest);
+    const s: [*]align(1) const Element = @ptrCast(src + n);
 
-    while (i < N) : (i += 128) {
-        dest[i..][0..32].* = src[i..][0..32].*;
-        dest[i + 32 ..][0..32].* = src[i + 32 ..][0..32].*;
-        dest[i + 64 ..][0..32].* = src[i + 64 ..][0..32].*;
-        dest[i + 96 ..][0..32].* = src[i + 96 ..][0..32].*;
+    const loop_count = n / (unroll_count * @sizeOf(Element));
+    var i: usize = 1;
+    while (i < loop_count + 1) : (i += 1) {
+        const du = d - (i * unroll_count);
+        const su = s - (i * unroll_count);
+        inline for (0..unroll_count) |j| {
+            du[unroll_count - 1 - j] = su[unroll_count - 1 - j];
+        }
     }
 
-    if (len - i <= 32) {
-        @branchHint(.unlikely);
-        dest[len - 32 ..][0..32].* = tail;
-    } else {
-        copyLong(dest[i..], src[i..], len - i);
+    inline for (tail_dest[0 .. unroll_count - 1], tail_bytes) |*dt, st| {
+        dt.* = st;
     }
+    dest[0..@sizeOf(Element)].* = start_bytes;
+
+    dest[len - @sizeOf(Element) ..][0..@sizeOf(Element)].* = end_bytes;
 }
 
-inline fn overlapBwd(dest: [*]u8, src: [*]const u8, len: usize) void {
-    var array: [5]@Vector(32, u8) = undefined;
-    array[0] = src[len - 32 ..][0..32].*;
-    inline for (1..5) |i| array[i] = src[(i - 1) << 5 ..][0..32].*;
-
-    const end: usize = (@intFromPtr(dest) + len - 32) & 31;
-    const range = len - end;
-    var s = src + range;
-    var d = dest + range;
-
-    while (@intFromPtr(s) > @intFromPtr(src + 128)) {
-        // zig fmt: off
-        const first  = @as(*align(1) const @Vector(32, u8), @ptrCast(s - 32)).*;
-        const second = @as(*align(1) const @Vector(32, u8), @ptrCast(s - 64)).*;
-        const third  = @as(*align(1) const @Vector(32, u8), @ptrCast(s - 96)).*;
-        const fourth = @as(*align(1) const @Vector(32, u8), @ptrCast(s - 128)).*;
-
-        @as(*align(32) @Vector(32, u8), @alignCast(@ptrCast(d - 32))).*  = first;
-        @as(*align(32) @Vector(32, u8), @alignCast(@ptrCast(d - 64))).*  = second;
-        @as(*align(32) @Vector(32, u8), @alignCast(@ptrCast(d - 96))).*  = third;
-        @as(*align(32) @Vector(32, u8), @alignCast(@ptrCast(d - 128))).* = fourth;
-        // zig fmt: on
-
-        s -= 128;
-        d -= 128;
+test memmoveFast {
+    const max_len = 1024;
+    var buffer: [max_len + @alignOf(Element) - 1]u8 = undefined;
+    for (&buffer, 0..) |*b, i| {
+        b.* = @intCast(i % 97);
     }
 
-    inline for (array[1..], 0..) |vec, i| dest[i * 32 ..][0..32].* = vec;
-    dest[len - 32 ..][0..32].* = array[0];
+    var move_buffer: [max_len + @alignOf(Element) - 1]u8 align(@alignOf(Element)) = undefined;
+
+    for (0..max_len) |copy_len| {
+        for (0..@alignOf(Element)) |s_offset| {
+            for (0..@alignOf(Element)) |d_offset| {
+                for (&move_buffer, buffer) |*d, s| {
+                    d.* = s;
+                }
+                const dest = move_buffer[d_offset..][0..copy_len];
+                const src = move_buffer[s_offset..][0..copy_len];
+                _ = memmoveFast(dest.ptr, src.ptr, copy_len);
+                std.testing.expectEqualSlices(u8, buffer[s_offset..][0..copy_len], dest) catch |e| {
+                    std.debug.print(
+                        "error occured with source offset {d} and destination offset {d}\n",
+                        .{
+                            s_offset,
+                            d_offset,
+                        },
+                    );
+                    return e;
+                };
+            }
+        }
+    }
 }