master
  1//===----------------------------------------------------------------------===//
  2//
  3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4// See https://llvm.org/LICENSE.txt for license information.
  5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6//
  7//===----------------------------------------------------------------------===//
  8
  9#ifndef _LIBCPP___PSTL_CPU_ALGOS_TRANSFORM_REDUCE_H
 10#define _LIBCPP___PSTL_CPU_ALGOS_TRANSFORM_REDUCE_H
 11
 12#include <__assert>
 13#include <__config>
 14#include <__iterator/concepts.h>
 15#include <__iterator/iterator_traits.h>
 16#include <__numeric/transform_reduce.h>
 17#include <__pstl/backend_fwd.h>
 18#include <__pstl/cpu_algos/cpu_traits.h>
 19#include <__type_traits/desugars_to.h>
 20#include <__type_traits/is_arithmetic.h>
 21#include <__type_traits/is_execution_policy.h>
 22#include <__utility/move.h>
 23#include <optional>
 24
 25#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
 26#  pragma GCC system_header
 27#endif
 28
 29_LIBCPP_PUSH_MACROS
 30#include <__undef_macros>
 31
 32#if _LIBCPP_STD_VER >= 17
 33
 34_LIBCPP_BEGIN_NAMESPACE_STD
 35namespace __pstl {
 36
 37template <typename _Backend,
 38          typename _DifferenceType,
 39          typename _Tp,
 40          typename _BinaryOperation,
 41          typename _UnaryOperation,
 42          typename _UnaryResult = invoke_result_t<_UnaryOperation, _DifferenceType>,
 43          __enable_if_t<__desugars_to_v<__plus_tag, _BinaryOperation, _Tp, _UnaryResult> && is_arithmetic_v<_Tp> &&
 44                            is_arithmetic_v<_UnaryResult>,
 45                        int>    = 0>
 46_LIBCPP_HIDE_FROM_ABI _Tp
 47__simd_transform_reduce(_DifferenceType __n, _Tp __init, _BinaryOperation, _UnaryOperation __f) noexcept {
 48  _PSTL_PRAGMA_SIMD_REDUCTION(+ : __init)
 49  for (_DifferenceType __i = 0; __i < __n; ++__i)
 50    __init += __f(__i);
 51  return __init;
 52}
 53
 54template <typename _Backend,
 55          typename _Size,
 56          typename _Tp,
 57          typename _BinaryOperation,
 58          typename _UnaryOperation,
 59          typename _UnaryResult = invoke_result_t<_UnaryOperation, _Size>,
 60          __enable_if_t<!(__desugars_to_v<__plus_tag, _BinaryOperation, _Tp, _UnaryResult> && is_arithmetic_v<_Tp> &&
 61                          is_arithmetic_v<_UnaryResult>),
 62                        int>    = 0>
 63_LIBCPP_HIDE_FROM_ABI _Tp
 64__simd_transform_reduce(_Size __n, _Tp __init, _BinaryOperation __binary_op, _UnaryOperation __f) noexcept {
 65  constexpr size_t __lane_size = __cpu_traits<_Backend>::__lane_size;
 66  const _Size __block_size     = __lane_size / sizeof(_Tp);
 67  if (__n > 2 * __block_size && __block_size > 1) {
 68    alignas(__lane_size) char __lane_buffer[__lane_size];
 69    _Tp* __lane = reinterpret_cast<_Tp*>(__lane_buffer);
 70
 71    // initializer
 72    _PSTL_PRAGMA_SIMD
 73    for (_Size __i = 0; __i < __block_size; ++__i) {
 74      ::new (__lane + __i) _Tp(__binary_op(__f(__i), __f(__block_size + __i)));
 75    }
 76    // main loop
 77    _Size __i                    = 2 * __block_size;
 78    const _Size __last_iteration = __block_size * (__n / __block_size);
 79    for (; __i < __last_iteration; __i += __block_size) {
 80      _PSTL_PRAGMA_SIMD
 81      for (_Size __j = 0; __j < __block_size; ++__j) {
 82        __lane[__j] = __binary_op(std::move(__lane[__j]), __f(__i + __j));
 83      }
 84    }
 85    // remainder
 86    _PSTL_PRAGMA_SIMD
 87    for (_Size __j = 0; __j < __n - __last_iteration; ++__j) {
 88      __lane[__j] = __binary_op(std::move(__lane[__j]), __f(__last_iteration + __j));
 89    }
 90    // combiner
 91    for (_Size __j = 0; __j < __block_size; ++__j) {
 92      __init = __binary_op(std::move(__init), std::move(__lane[__j]));
 93    }
 94    // destroyer
 95    _PSTL_PRAGMA_SIMD
 96    for (_Size __j = 0; __j < __block_size; ++__j) {
 97      __lane[__j].~_Tp();
 98    }
 99  } else {
100    for (_Size __i = 0; __i < __n; ++__i) {
101      __init = __binary_op(std::move(__init), __f(__i));
102    }
103  }
104  return __init;
105}
106
107template <class _Backend, class _RawExecutionPolicy>
108struct __cpu_parallel_transform_reduce_binary {
109  template <class _Policy,
110            class _ForwardIterator1,
111            class _ForwardIterator2,
112            class _Tp,
113            class _BinaryOperation1,
114            class _BinaryOperation2>
115  _LIBCPP_HIDE_FROM_ABI optional<_Tp> operator()(
116      _Policy&& __policy,
117      _ForwardIterator1 __first1,
118      _ForwardIterator1 __last1,
119      _ForwardIterator2 __first2,
120      _Tp __init,
121      _BinaryOperation1 __reduce,
122      _BinaryOperation2 __transform) const noexcept {
123    if constexpr (__is_parallel_execution_policy_v<_RawExecutionPolicy> &&
124                  __has_random_access_iterator_category_or_concept<_ForwardIterator1>::value &&
125                  __has_random_access_iterator_category_or_concept<_ForwardIterator2>::value) {
126      return __cpu_traits<_Backend>::__transform_reduce(
127          __first1,
128          std::move(__last1),
129          [__first1, __first2, __transform](_ForwardIterator1 __iter) {
130            return __transform(*__iter, *(__first2 + (__iter - __first1)));
131          },
132          std::move(__init),
133          std::move(__reduce),
134          [&__policy, __first1, __first2, __reduce, __transform](
135              _ForwardIterator1 __brick_first, _ForwardIterator1 __brick_last, _Tp __brick_init) {
136            using _TransformReduceBinaryUnseq =
137                __pstl::__transform_reduce_binary<_Backend, __remove_parallel_policy_t<_RawExecutionPolicy>>;
138            return *_TransformReduceBinaryUnseq()(
139                std::__remove_parallel_policy(__policy),
140                __brick_first,
141                std::move(__brick_last),
142                __first2 + (__brick_first - __first1),
143                std::move(__brick_init),
144                std::move(__reduce),
145                std::move(__transform));
146          });
147    } else if constexpr (__is_unsequenced_execution_policy_v<_RawExecutionPolicy> &&
148                         __has_random_access_iterator_category_or_concept<_ForwardIterator1>::value &&
149                         __has_random_access_iterator_category_or_concept<_ForwardIterator2>::value) {
150      return __pstl::__simd_transform_reduce<_Backend>(
151          __last1 - __first1, std::move(__init), std::move(__reduce), [&](__iter_diff_t<_ForwardIterator1> __i) {
152            return __transform(__first1[__i], __first2[__i]);
153          });
154    } else {
155      return std::transform_reduce(
156          std::move(__first1),
157          std::move(__last1),
158          std::move(__first2),
159          std::move(__init),
160          std::move(__reduce),
161          std::move(__transform));
162    }
163  }
164};
165
166template <class _Backend, class _RawExecutionPolicy>
167struct __cpu_parallel_transform_reduce {
168  template <class _Policy, class _ForwardIterator, class _Tp, class _BinaryOperation, class _UnaryOperation>
169  _LIBCPP_HIDE_FROM_ABI optional<_Tp>
170  operator()(_Policy&& __policy,
171             _ForwardIterator __first,
172             _ForwardIterator __last,
173             _Tp __init,
174             _BinaryOperation __reduce,
175             _UnaryOperation __transform) const noexcept {
176    if constexpr (__is_parallel_execution_policy_v<_RawExecutionPolicy> &&
177                  __has_random_access_iterator_category_or_concept<_ForwardIterator>::value) {
178      return __cpu_traits<_Backend>::__transform_reduce(
179          std::move(__first),
180          std::move(__last),
181          [__transform](_ForwardIterator __iter) { return __transform(*__iter); },
182          std::move(__init),
183          __reduce,
184          [&__policy, __transform, __reduce](auto __brick_first, auto __brick_last, _Tp __brick_init) {
185            using _TransformReduceUnseq =
186                __pstl::__transform_reduce<_Backend, __remove_parallel_policy_t<_RawExecutionPolicy>>;
187            auto __res = _TransformReduceUnseq()(
188                std::__remove_parallel_policy(__policy),
189                std::move(__brick_first),
190                std::move(__brick_last),
191                std::move(__brick_init),
192                std::move(__reduce),
193                std::move(__transform));
194            _LIBCPP_ASSERT_INTERNAL(__res, "unseq/seq should never try to allocate!");
195            return *std::move(__res);
196          });
197    } else if constexpr (__is_unsequenced_execution_policy_v<_RawExecutionPolicy> &&
198                         __has_random_access_iterator_category_or_concept<_ForwardIterator>::value) {
199      return __pstl::__simd_transform_reduce<_Backend>(
200          __last - __first,
201          std::move(__init),
202          std::move(__reduce),
203          [=, &__transform](__iter_diff_t<_ForwardIterator> __i) { return __transform(__first[__i]); });
204    } else {
205      return std::transform_reduce(
206          std::move(__first), std::move(__last), std::move(__init), std::move(__reduce), std::move(__transform));
207    }
208  }
209};
210
211} // namespace __pstl
212_LIBCPP_END_NAMESPACE_STD
213
214#endif // _LIBCPP_STD_VER >= 17
215
216_LIBCPP_POP_MACROS
217
218#endif // _LIBCPP___PSTL_CPU_ALGOS_TRANSFORM_REDUCE_H