vir-simd 0.4.189
Parallelism TS 2 extensions and simd fallback implementation
Loading...
Searching...
No Matches
simd_permute.h
Go to the documentation of this file.
1/* SPDX-License-Identifier: LGPL-3.0-or-later */
2/* Copyright © 2023–2024 GSI Helmholtzzentrum fuer Schwerionenforschung GmbH
3 * Matthias Kretz <m.kretz@gsi.de>
4 */
5
6// Implements non-members of P2664
7
8#ifndef VIR_SIMD_PERMUTE_H_
9#define VIR_SIMD_PERMUTE_H_
10
14
15#include "simd_concepts.h"
16
17#if VIR_HAVE_SIMD_CONCEPTS
18#define VIR_HAVE_SIMD_PERMUTE 1
19
20#include "simd.h"
21#include "detail.h"
22#include "constexpr_wrapper.h"
23#include <bit>
24
25namespace vir
26{
27 namespace detail
28 {
29 template <typename F>
30 concept index_permutation_function_nosize = requires(F const& f)
31 {
32 { f(vir::cw<0>) } -> std::integral;
33#if __GNUC__ >= 15
34 typename std::integral_constant<int, f(vir::cw<0>)>;
35#endif
36 };
37
38 template <typename F, std::size_t Size>
39 concept index_permutation_function_size = requires(F const& f)
40 {
41 { f(vir::cw<0>, vir::cw<Size>) } -> std::integral;
42#if __GNUC__ >= 15
43 typename std::integral_constant<int, f(vir::cw<0>, vir::cw<Size>)>;
44#endif
45 };
46
47 template <typename F, std::size_t Size>
48 concept index_permutation_function
49 = index_permutation_function_size<F, Size> or index_permutation_function_nosize<F>;
50 }
51
53 constexpr int simd_permute_zero = std::numeric_limits<int>::max();
54
57
// VIR_CONSTEVAL is used on the call operators of the predefined permutations below. It expands
// to `consteval`, except on the compilers matched by the condition, where it is `constexpr`.
// NOTE(review): Clang predefines __clang__ as 1 (its version is in __clang_major__), so the
// `__clang__ <= 13` test holds for every Clang release, i.e. all Clang versions get constexpr.
// If only Clang <= 13 was meant, test __clang_major__ instead — but first confirm that consteval
// call operators still work with simd_permute's call sites on newer Clang.
#if defined(__clang__) && (__clang__ <= 13)
#  define VIR_CONSTEVAL constexpr
#else
#  define VIR_CONSTEVAL consteval
#endif
63
66 {
67 struct DuplicateEven
68 {
69 VIR_CONSTEVAL unsigned
70 operator()(unsigned i) const
71 { return i & ~1u; }
72 };
73
75 inline constexpr DuplicateEven duplicate_even {};
76
77 struct DuplicateOdd
78 {
79 VIR_CONSTEVAL unsigned
80 operator()(unsigned i) const
81 { return i | 1u; }
82 };
83
85 inline constexpr DuplicateOdd duplicate_odd {};
86
87 template <unsigned N>
88 struct SwapNeighbors
89 {
90 VIR_CONSTEVAL unsigned
91 operator()(unsigned i, auto size) const
92 {
93 static_assert(size % (2 * N) == 0,
94 "swap_neighbors<N> permutation requires a multiple of 2N elements");
95 if (std::has_single_bit(N))
96 return i ^ N;
97 else if (i % (2 * N) >= N)
98 return i - N;
99 else
100 return i + N;
101 }
102 };
103
109 template <unsigned N = 1u>
110 inline constexpr SwapNeighbors<N> swap_neighbors {};
111
112 template <int Position>
113 struct Broadcast
114 {
115 VIR_CONSTEVAL int
116 operator()(int) const
117 { return Position; }
118 };
119
121 template <int Position>
122 inline constexpr Broadcast<Position> broadcast {};
123
125 inline constexpr Broadcast<0> broadcast_first {};
126
128 inline constexpr Broadcast<-1> broadcast_last {};
129
130 struct Reverse
131 {
132 VIR_CONSTEVAL int
133 operator()(int i) const
134 { return -1 - i; }
135 };
136
138 inline constexpr Reverse reverse {};
139
140 template <int O>
141 struct Rotate
142 {
143 static constexpr int Offset = O;
144 static constexpr bool is_even_rotation = Offset % 2 == 0;
145
146 VIR_CONSTEVAL int
147 operator()(int i, auto size) const
148 { return (i + Offset) % size.value; }
149 };
150
152 template <int Offset>
153 inline constexpr Rotate<Offset> rotate {};
154
155 template <int Offset>
156 struct Shift
157 {
158 VIR_CONSTEVAL int
159 operator()(int i, int size) const
160 {
161 const int j = i + Offset;
162 if constexpr (Offset >= 0)
163 return j >= size ? simd_permute_zero : j;
164 else
165 return j < 0 ? simd_permute_zero : j;
166 }
167 };
168
170 template <int Offset>
171 inline constexpr Shift<Offset> shift {};
172 }
173
// VIR_CONSTEVAL was only needed for the predefined permutations above; undefine it so it does
// not remain visible to code that includes this header.
#  undef VIR_CONSTEVAL
175
178 template <std::size_t N = 0, vir::any_simd_or_mask V,
179 detail::index_permutation_function<V::size()> F>
180 VIR_ALWAYS_INLINE constexpr stdx::resize_simd_t<N == 0 ? V::size() : N, V>
181 simd_permute(V const& v, F const idx_perm) noexcept
182 {
183 using T = typename V::value_type;
184 using R = stdx::resize_simd_t<N == 0 ? V::size() : N, V>;
185
186#if defined __GNUC__
187 if (not std::is_constant_evaluated())
188 if constexpr (std::has_single_bit(sizeof(V)) and V::size() <= stdx::native_simd<T>::size())
189 {
190#if defined __AVX2__
191 using v4df [[gnu::vector_size(32)]] = double;
192 if constexpr (std::same_as<T, float> and std::is_trivially_copyable_v<V>
193 and sizeof(v4df) == sizeof(V)
194 and requires {
195 F::is_even_rotation;
196 F::Offset;
197 { std::bool_constant<F::is_even_rotation>() }
198 -> std::same_as<std::true_type>;
199 })
200 {
201 const v4df intrin = detail::bit_cast<v4df>(v);
202 constexpr int control = ((F::Offset / 2) << 0)
203 | (((F::Offset / 2 + 1) % 4) << 2)
204 | (((F::Offset / 2 + 2) % 4) << 4)
205 | (((F::Offset / 2 + 3) % 4) << 6);
206 return detail::bit_cast<R>(__builtin_ia32_permdf256(intrin, control));
207 }
208#endif
209#if VIR_HAVE_WORKING_SHUFFLEVECTOR
210 if constexpr (std::has_single_bit(sizeof(V)) and std::has_single_bit(sizeof(R)))
211 {
212 using VBuiltin [[gnu::vector_size(sizeof(V))]] = T;
213 using RBuiltin [[gnu::vector_size(sizeof(R))]] = T;
214 if constexpr (std::is_trivially_copyable_v<V> and std::is_trivially_copyable_v<R>
215 and sizeof(VBuiltin) == sizeof(V) and sizeof(RBuiltin) == sizeof(R))
216 {
217 const VBuiltin vec = detail::bit_cast<VBuiltin>(v);
218 constexpr auto idx_perm2 = [=](constexpr_value auto i) {
219 if constexpr (detail::index_permutation_function_nosize<F>)
220 return vir::cw<idx_perm(i)>;
221 else
222 return vir::cw<idx_perm(i, vir::cw<V::size()>)>;
223 };
224 constexpr auto adj_idx = [](constexpr_value auto i) {
225 constexpr int j = i;
226 if constexpr (j == simd_permute_zero)
227 return vir::cw<V::size()>;
228 else if constexpr (j == simd_permute_uninit)
229 return vir::cw<-1>;
230 else if constexpr (j < 0)
231 {
232 static_assert (-j <= int(V::size()));
233 return vir::cw<int(V::size()) + j>;
234 }
235 else
236 {
237 static_assert (j < int(V::size()));
238 return vir::cw<j>;
239 }
240 };
241 return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
242 return detail::bit_cast<R>(
243 __builtin_shufflevector(vec, VBuiltin{},
244 adj_idx(idx_perm2(vir::cw<Is>)).value...));
245 }(std::make_index_sequence<R::size()>());
246 }
247 }
248#endif
249 }
250#endif // __GNUC__
251
252 return R([&](auto i) -> T {
253 constexpr int j = [&] {
254 if constexpr (detail::index_permutation_function_nosize<F>)
255 return idx_perm(i);
256 else
257 return idx_perm(i, vir::cw<V::size()>);
258 }();
259 if constexpr (j == simd_permute_zero)
260 return 0;
261 else if constexpr (j == simd_permute_uninit)
262 {
263 T uninit;
264 return uninit;
265 }
266 else if constexpr (j < 0)
267 {
268 static_assert(-j <= int(V::size()));
269 return v[v.size() + j];
270 }
271 else
272 {
273 static_assert(j < int(V::size()));
274 return v[j];
275 }
276 });
277 }
278
280 template <std::size_t N = 0, vir::vectorizable T, detail::index_permutation_function<1> F>
281 VIR_ALWAYS_INLINE constexpr
282 std::conditional_t<N <= 1, T, stdx::resize_simd_t<N == 0 ? 1 : N, stdx::simd<T>>>
283 simd_permute(T const& v, F const idx_perm) noexcept
284 {
285 if constexpr (N <= 1)
286 {
287 constexpr auto i = vir::cw<0>;
288 constexpr int j = [&] {
289 if constexpr (detail::index_permutation_function_nosize<F>)
290 return idx_perm(i);
291 else
292 return idx_perm(i, vir::cw<std::size_t(1)>);
293 }();
294 if constexpr (j == simd_permute_zero)
295 return 0;
296 else if constexpr (j == simd_permute_uninit)
297 {
298 T uninit;
299 return uninit;
300 }
301 else
302 {
303 static_assert(j == 0 or j == -1);
304 return v;
305 }
306 }
307 else
308 return simd_permute<N>(stdx::simd<T, stdx::simd_abi::scalar>(v), idx_perm);
309 }
310
312 template <int Offset, vir::any_simd_or_mask V>
313 VIR_ALWAYS_INLINE constexpr V
314 simd_shift_in(V const& a, std::convertible_to<V> auto const&... more) noexcept
315 {
316 return V([&](auto i) -> typename V::value_type {
317 constexpr int ninputs = 1 + sizeof...(more);
318 constexpr int w = V::size();
319 constexpr int j = Offset + int(i);
320 if constexpr (j >= w * ninputs)
321 return 0;
322 else if constexpr (j >= 0)
323 {
324 const V tmp[] = {a, more...};
325 return tmp[j / w][j % w];
326 }
327 else if constexpr (j < -w)
328 return 0;
329 else
330 return a[w + j];
331 });
332 }
333}
334
335#endif // has concepts
336#endif // VIR_SIMD_PERMUTE_H_
337
338// vim: noet cc=101 tw=100 sw=2 ts=8
Satisfied if V is either a simd or a simd_mask.
Definition simd_concepts.h:61
Predefined permutations.
Definition simd_permute.h:66
constexpr Reverse reverse
Reverse the elements.
Definition simd_permute.h:138
constexpr SwapNeighbors< N > swap_neighbors
Swaps each group of N neighboring elements with the adjacent group of N elements.
Definition simd_permute.h:110
constexpr DuplicateOdd duplicate_odd
Copies odd elements into even elements.
Definition simd_permute.h:85
constexpr Shift< Offset > shift
Shift the elements by Offset, filling the vacated elements with zero.
Definition simd_permute.h:171
constexpr Broadcast<-1 > broadcast_last
Copy the last element into all elements.
Definition simd_permute.h:128
constexpr Broadcast< Position > broadcast
Copy element at index Position into all elements.
Definition simd_permute.h:122
constexpr DuplicateEven duplicate_even
Copies even elements into odd elements.
Definition simd_permute.h:75
constexpr Broadcast< 0 > broadcast_first
Copy the first element into all elements.
Definition simd_permute.h:125
constexpr Rotate< Offset > rotate
Rotate the elements by Offset.
Definition simd_permute.h:153
This namespace collects libraries and tools authored by Matthias Kretz.
Definition constexpr_wrapper.h:21
constexpr int simd_permute_zero
Constant that requests a zero value instead of one of the input values.
Definition simd_permute.h:53
constexpr stdx::resize_simd_t< N==0 ? V::size() :N, V > simd_permute(V const &v, F const idx_perm) noexcept
Permute the elements of v using the index permutation function idx_perm.
Definition simd_permute.h:181
constexpr int simd_permute_uninit
Constant that allows an arbitrary value instead of one of the input values.
Definition simd_permute.h:56
constexpr V simd_shift_in(V const &a, std::convertible_to< V > auto const &... more) noexcept
Concatenate a, more..., shift by Offset, and return the first V::size() elements.
Definition simd_permute.h:314
C++20 concepts extending the Parallelism TS 2 (which is limited to C++17).