master
  1/*===--------- avx512vlbf16intrin.h - AVX512_BF16 intrinsics ---------------===
  2 *
  3 * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4 * See https://llvm.org/LICENSE.txt for license information.
  5 * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6 *
  7 *===-----------------------------------------------------------------------===
  8 */
  9#ifndef __IMMINTRIN_H
 10#error "Never use <avx512vlbf16intrin.h> directly; include <immintrin.h> instead."
 11#endif
 12
 13#ifdef __SSE2__
 14
 15#ifndef __AVX512VLBF16INTRIN_H
 16#define __AVX512VLBF16INTRIN_H
 17
// Attribute list attached to the inline functions below that are tagged
// __DEFAULT_FN_ATTRS128: always_inline, nodebug, target features
// "avx512vl,avx512bf16,no-evex512", and min_vector_width(128).
#define __DEFAULT_FN_ATTRS128                                                  \
  __attribute__((__always_inline__, __nodebug__,                               \
                 __target__("avx512vl,avx512bf16,no-evex512"),                 \
                 __min_vector_width__(128)))
// Same attribute list as __DEFAULT_FN_ATTRS128 except that the minimum vector
// width is 256.
#define __DEFAULT_FN_ATTRS256                                                  \
  __attribute__((__always_inline__, __nodebug__,                               \
                 __target__("avx512vl,avx512bf16,no-evex512"),                 \
                 __min_vector_width__(256)))
 26
 27/// Convert Two Packed Single Data to One Packed BF16 Data.
 28///
 29/// \headerfile <x86intrin.h>
 30///
 31/// This intrinsic corresponds to the <c> VCVTNE2PS2BF16 </c> instructions.
 32///
 33/// \param __A
 34///    A 128-bit vector of [4 x float].
 35/// \param __B
 36///    A 128-bit vector of [4 x float].
 37/// \returns A 128-bit vector of [8 x bfloat] whose lower 64 bits come from
 38///    conversion of __B, and higher 64 bits come from conversion of __A.
 39static __inline__ __m128bh __DEFAULT_FN_ATTRS128
 40_mm_cvtne2ps_pbh(__m128 __A, __m128 __B) {
 41  return (__m128bh)__builtin_ia32_cvtne2ps2bf16_128((__v4sf) __A,
 42                                                    (__v4sf) __B);
 43}
 44
 45/// Convert Two Packed Single Data to One Packed BF16 Data.
 46///
 47/// \headerfile <x86intrin.h>
 48///
 49/// This intrinsic corresponds to the <c> VCVTNE2PS2BF16 </c> instructions.
 50///
 51/// \param __A
 52///    A 128-bit vector of [4 x float].
 53/// \param __B
 54///    A 128-bit vector of [4 x float].
 55/// \param __W
 56///    A 128-bit vector of [8 x bfloat].
 57/// \param __U
 58///    A 8-bit mask value specifying what is chosen for each element.
 59///    A 1 means conversion of __A or __B. A 0 means element from __W.
 60/// \returns A 128-bit vector of [8 x bfloat] whose lower 64 bits come from
 61///    conversion of __B, and higher 64 bits come from conversion of __A.
 62static __inline__ __m128bh __DEFAULT_FN_ATTRS128
 63_mm_mask_cvtne2ps_pbh(__m128bh __W, __mmask8 __U, __m128 __A, __m128 __B) {
 64  return (__m128bh)__builtin_ia32_selectpbf_128((__mmask8)__U,
 65                                             (__v8bf)_mm_cvtne2ps_pbh(__A, __B),
 66                                             (__v8bf)__W);
 67}
 68
 69/// Convert Two Packed Single Data to One Packed BF16 Data.
 70///
 71/// \headerfile <x86intrin.h>
 72///
 73/// This intrinsic corresponds to the <c> VCVTNE2PS2BF16 </c> instructions.
 74///
 75/// \param __A
 76///    A 128-bit vector of [4 x float].
 77/// \param __B
 78///    A 128-bit vector of [4 x float].
 79/// \param __U
 80///    A 8-bit mask value specifying what is chosen for each element.
 81///    A 1 means conversion of __A or __B. A 0 means element is zero.
 82/// \returns A 128-bit vector of [8 x bfloat] whose lower 64 bits come from
 83///    conversion of __B, and higher 64 bits come from conversion of __A.
 84static __inline__ __m128bh __DEFAULT_FN_ATTRS128
 85_mm_maskz_cvtne2ps_pbh(__mmask8 __U, __m128 __A, __m128 __B) {
 86  return (__m128bh)__builtin_ia32_selectpbf_128((__mmask8)__U,
 87                                             (__v8bf)_mm_cvtne2ps_pbh(__A, __B),
 88                                             (__v8bf)_mm_setzero_si128());
 89}
 90
 91/// Convert Two Packed Single Data to One Packed BF16 Data.
 92///
 93/// \headerfile <x86intrin.h>
 94///
 95/// This intrinsic corresponds to the <c> VCVTNE2PS2BF16 </c> instructions.
 96///
 97/// \param __A
 98///    A 256-bit vector of [8 x float].
 99/// \param __B
100///    A 256-bit vector of [8 x float].
101/// \returns A 256-bit vector of [16 x bfloat] whose lower 128 bits come from
102///    conversion of __B, and higher 128 bits come from conversion of __A.
103static __inline__ __m256bh __DEFAULT_FN_ATTRS256
104_mm256_cvtne2ps_pbh(__m256 __A, __m256 __B) {
105  return (__m256bh)__builtin_ia32_cvtne2ps2bf16_256((__v8sf) __A,
106                                                    (__v8sf) __B);
107}
108
109/// Convert Two Packed Single Data to One Packed BF16 Data.
110///
111/// \headerfile <x86intrin.h>
112///
113/// This intrinsic corresponds to the <c> VCVTNE2PS2BF16 </c> instructions.
114///
115/// \param __A
116///    A 256-bit vector of [8 x float].
117/// \param __B
118///    A 256-bit vector of [8 x float].
119/// \param __W
120///    A 256-bit vector of [16 x bfloat].
121/// \param __U
122///    A 16-bit mask value specifying what is chosen for each element.
123///    A 1 means conversion of __A or __B. A 0 means element from __W.
124/// \returns A 256-bit vector of [16 x bfloat] whose lower 128 bits come from
125///    conversion of __B, and higher 128 bits come from conversion of __A.
126static __inline__ __m256bh __DEFAULT_FN_ATTRS256
127_mm256_mask_cvtne2ps_pbh(__m256bh __W, __mmask16 __U, __m256 __A, __m256 __B) {
128  return (__m256bh)__builtin_ia32_selectpbf_256((__mmask16)__U,
129                                         (__v16bf)_mm256_cvtne2ps_pbh(__A, __B),
130                                         (__v16bf)__W);
131}
132
133/// Convert Two Packed Single Data to One Packed BF16 Data.
134///
135/// \headerfile <x86intrin.h>
136///
137/// This intrinsic corresponds to the <c> VCVTNE2PS2BF16 </c> instructions.
138///
139/// \param __A
140///    A 256-bit vector of [8 x float].
141/// \param __B
142///    A 256-bit vector of [8 x float].
143/// \param __U
144///    A 16-bit mask value specifying what is chosen for each element.
145///    A 1 means conversion of __A or __B. A 0 means element is zero.
146/// \returns A 256-bit vector of [16 x bfloat] whose lower 128 bits come from
147///    conversion of __B, and higher 128 bits come from conversion of __A.
148static __inline__ __m256bh __DEFAULT_FN_ATTRS256
149_mm256_maskz_cvtne2ps_pbh(__mmask16 __U, __m256 __A, __m256 __B) {
150  return (__m256bh)__builtin_ia32_selectpbf_256((__mmask16)__U,
151                                         (__v16bf)_mm256_cvtne2ps_pbh(__A, __B),
152                                         (__v16bf)_mm256_setzero_si256());
153}
154
/// Convert Packed Single Data to Packed BF16 Data.
///
/// \headerfile <x86intrin.h>
///
/// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions.
///
/// \param A
///    A 128-bit vector of [4 x float]. The macro casts this argument to
///    __v4sf before handing it to the builtin.
/// \returns A 128-bit vector of [8 x bfloat] whose lower 64 bits come from
///    conversion of A, and higher 64 bits are 0.
#define _mm_cvtneps_pbh(A)                                                     \
  ((__m128bh)__builtin_ia32_vcvtneps2bf16128((__v4sf)(A)))
167
168/// Convert Packed Single Data to Packed BF16 Data.
169///
170/// \headerfile <x86intrin.h>
171///
172/// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions.
173///
174/// \param __A
175///    A 128-bit vector of [4 x float].
176/// \param __W
177///    A 128-bit vector of [8 x bfloat].
178/// \param __U
179///    A 4-bit mask value specifying what is chosen for each element.
180///    A 1 means conversion of __A. A 0 means element from __W.
181/// \returns A 128-bit vector of [8 x bfloat] whose lower 64 bits come from
182///    conversion of __A, and higher 64 bits are 0.
183static __inline__ __m128bh __DEFAULT_FN_ATTRS128
184_mm_mask_cvtneps_pbh(__m128bh __W, __mmask8 __U, __m128 __A) {
185  return (__m128bh)__builtin_ia32_cvtneps2bf16_128_mask((__v4sf) __A,
186                                                        (__v8bf)__W,
187                                                        (__mmask8)__U);
188}
189
190/// Convert Packed Single Data to Packed BF16 Data.
191///
192/// \headerfile <x86intrin.h>
193///
194/// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions.
195///
196/// \param __A
197///    A 128-bit vector of [4 x float].
198/// \param __U
199///    A 4-bit mask value specifying what is chosen for each element.
200///    A 1 means conversion of __A. A 0 means element is zero.
201/// \returns A 128-bit vector of [8 x bfloat] whose lower 64 bits come from
202///    conversion of __A, and higher 64 bits are 0.
203static __inline__ __m128bh __DEFAULT_FN_ATTRS128
204_mm_maskz_cvtneps_pbh(__mmask8 __U, __m128 __A) {
205  return (__m128bh)__builtin_ia32_cvtneps2bf16_128_mask((__v4sf) __A,
206                                                    (__v8bf)_mm_setzero_si128(),
207                                                    (__mmask8)__U);
208}
209
/// Convert Packed Single Data to Packed BF16 Data.
///
/// \headerfile <x86intrin.h>
///
/// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions.
///
/// \param A
///    A 256-bit vector of [8 x float]. The macro casts this argument to
///    __v8sf before handing it to the builtin.
/// \returns A 128-bit vector of [8 x bfloat] that comes from conversion of A.
#define _mm256_cvtneps_pbh(A)                                                  \
  ((__m128bh)__builtin_ia32_vcvtneps2bf16256((__v8sf)(A)))
221
222/// Convert Packed Single Data to Packed BF16 Data.
223///
224/// \headerfile <x86intrin.h>
225///
226/// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions.
227///
228/// \param __A
229///    A 256-bit vector of [8 x float].
230/// \param __W
231///    A 256-bit vector of [8 x bfloat].
232/// \param __U
233///    A 8-bit mask value specifying what is chosen for each element.
234///    A 1 means conversion of __A. A 0 means element from __W.
235/// \returns A 128-bit vector of [8 x bfloat] comes from conversion of __A.
236static __inline__ __m128bh __DEFAULT_FN_ATTRS256
237_mm256_mask_cvtneps_pbh(__m128bh __W, __mmask8 __U, __m256 __A) {
238  return (__m128bh)__builtin_ia32_cvtneps2bf16_256_mask((__v8sf)__A,
239                                                        (__v8bf)__W,
240                                                        (__mmask8)__U);
241}
242
243/// Convert Packed Single Data to Packed BF16 Data.
244///
245/// \headerfile <x86intrin.h>
246///
247/// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions.
248///
249/// \param __A
250///    A 256-bit vector of [8 x float].
251/// \param __U
252///    A 8-bit mask value specifying what is chosen for each element.
253///    A 1 means conversion of __A. A 0 means element is zero.
254/// \returns A 128-bit vector of [8 x bfloat] comes from conversion of __A.
255static __inline__ __m128bh __DEFAULT_FN_ATTRS256
256_mm256_maskz_cvtneps_pbh(__mmask8 __U, __m256 __A) {
257  return (__m128bh)__builtin_ia32_cvtneps2bf16_256_mask((__v8sf)__A,
258                                                    (__v8bf)_mm_setzero_si128(),
259                                                    (__mmask8)__U);
260}
261
262/// Dot Product of BF16 Pairs Accumulated into Packed Single Precision.
263///
264/// \headerfile <x86intrin.h>
265///
266/// This intrinsic corresponds to the <c> VDPBF16PS </c> instructions.
267///
268/// \param __A
269///    A 128-bit vector of [8 x bfloat].
270/// \param __B
271///    A 128-bit vector of [8 x bfloat].
272/// \param __D
273///    A 128-bit vector of [4 x float].
274/// \returns A 128-bit vector of [4 x float] comes from  Dot Product of
275///  __A, __B and __D
276static __inline__ __m128 __DEFAULT_FN_ATTRS128
277_mm_dpbf16_ps(__m128 __D, __m128bh __A, __m128bh __B) {
278  return (__m128)__builtin_ia32_dpbf16ps_128((__v4sf)__D,
279                                             (__v8bf)__A,
280                                             (__v8bf)__B);
281}
282
283/// Dot Product of BF16 Pairs Accumulated into Packed Single Precision.
284///
285/// \headerfile <x86intrin.h>
286///
287/// This intrinsic corresponds to the <c> VDPBF16PS </c> instructions.
288///
289/// \param __A
290///    A 128-bit vector of [8 x bfloat].
291/// \param __B
292///    A 128-bit vector of [8 x bfloat].
293/// \param __D
294///    A 128-bit vector of [4 x float].
295/// \param __U
296///    A 8-bit mask value specifying what is chosen for each element.
297///    A 1 means __A and __B's dot product accumulated with __D. A 0 means __D.
298/// \returns A 128-bit vector of [4 x float] comes from  Dot Product of
299///  __A, __B and __D
300static __inline__ __m128 __DEFAULT_FN_ATTRS128
301_mm_mask_dpbf16_ps(__m128 __D, __mmask8 __U, __m128bh __A, __m128bh __B) {
302  return (__m128)__builtin_ia32_selectps_128((__mmask8)__U,
303                                           (__v4sf)_mm_dpbf16_ps(__D, __A, __B),
304                                           (__v4sf)__D);
305}
306
307/// Dot Product of BF16 Pairs Accumulated into Packed Single Precision.
308///
309/// \headerfile <x86intrin.h>
310///
311/// This intrinsic corresponds to the <c> VDPBF16PS </c> instructions.
312///
313/// \param __A
314///    A 128-bit vector of [8 x bfloat].
315/// \param __B
316///    A 128-bit vector of [8 x bfloat].
317/// \param __D
318///    A 128-bit vector of [4 x float].
319/// \param __U
320///    A 8-bit mask value specifying what is chosen for each element.
321///    A 1 means __A and __B's dot product accumulated with __D. A 0 means 0.
322/// \returns A 128-bit vector of [4 x float] comes from  Dot Product of
323///  __A, __B and __D
324static __inline__ __m128 __DEFAULT_FN_ATTRS128
325_mm_maskz_dpbf16_ps(__mmask8 __U, __m128 __D, __m128bh __A, __m128bh __B) {
326  return (__m128)__builtin_ia32_selectps_128((__mmask8)__U,
327                                           (__v4sf)_mm_dpbf16_ps(__D, __A, __B),
328                                           (__v4sf)_mm_setzero_si128());
329}
330
331/// Dot Product of BF16 Pairs Accumulated into Packed Single Precision.
332///
333/// \headerfile <x86intrin.h>
334///
335/// This intrinsic corresponds to the <c> VDPBF16PS </c> instructions.
336///
337/// \param __A
338///    A 256-bit vector of [16 x bfloat].
339/// \param __B
340///    A 256-bit vector of [16 x bfloat].
341/// \param __D
342///    A 256-bit vector of [8 x float].
343/// \returns A 256-bit vector of [8 x float] comes from  Dot Product of
344///  __A, __B and __D
345static __inline__ __m256 __DEFAULT_FN_ATTRS256
346_mm256_dpbf16_ps(__m256 __D, __m256bh __A, __m256bh __B) {
347  return (__m256)__builtin_ia32_dpbf16ps_256((__v8sf)__D,
348                                             (__v16bf)__A,
349                                             (__v16bf)__B);
350}
351
352/// Dot Product of BF16 Pairs Accumulated into Packed Single Precision.
353///
354/// \headerfile <x86intrin.h>
355///
356/// This intrinsic corresponds to the <c> VDPBF16PS </c> instructions.
357///
358/// \param __A
359///    A 256-bit vector of [16 x bfloat].
360/// \param __B
361///    A 256-bit vector of [16 x bfloat].
362/// \param __D
363///    A 256-bit vector of [8 x float].
364/// \param __U
365///    A 16-bit mask value specifying what is chosen for each element.
366///    A 1 means __A and __B's dot product accumulated with __D. A 0 means __D.
367/// \returns A 256-bit vector of [8 x float] comes from  Dot Product of
368///  __A, __B and __D
369static __inline__ __m256 __DEFAULT_FN_ATTRS256
370_mm256_mask_dpbf16_ps(__m256 __D, __mmask8 __U, __m256bh __A, __m256bh __B) {
371  return (__m256)__builtin_ia32_selectps_256((__mmask8)__U,
372                                        (__v8sf)_mm256_dpbf16_ps(__D, __A, __B),
373                                        (__v8sf)__D);
374}
375
376/// Dot Product of BF16 Pairs Accumulated into Packed Single Precision.
377///
378/// \headerfile <x86intrin.h>
379///
380/// This intrinsic corresponds to the <c> VDPBF16PS </c> instructions.
381///
382/// \param __A
383///    A 256-bit vector of [16 x bfloat].
384/// \param __B
385///    A 256-bit vector of [16 x bfloat].
386/// \param __D
387///    A 256-bit vector of [8 x float].
388/// \param __U
389///    A 8-bit mask value specifying what is chosen for each element.
390///    A 1 means __A and __B's dot product accumulated with __D. A 0 means 0.
391/// \returns A 256-bit vector of [8 x float] comes from  Dot Product of
392///  __A, __B and __D
393static __inline__ __m256 __DEFAULT_FN_ATTRS256
394_mm256_maskz_dpbf16_ps(__mmask8 __U, __m256 __D, __m256bh __A, __m256bh __B) {
395  return (__m256)__builtin_ia32_selectps_256((__mmask8)__U,
396                                        (__v8sf)_mm256_dpbf16_ps(__D, __A, __B),
397                                        (__v8sf)_mm256_setzero_si256());
398}
399
400/// Convert One Single float Data to One BF16 Data.
401///
402/// \headerfile <x86intrin.h>
403///
404/// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions.
405///
406/// \param __A
407///    A float data.
408/// \returns A bf16 data whose sign field and exponent field keep unchanged,
409///    and fraction field is truncated to 7 bits.
410static __inline__ __bf16 __DEFAULT_FN_ATTRS128 _mm_cvtness_sbh(float __A) {
411  __v4sf __V = {__A, 0, 0, 0};
412  __v8bf __R = __builtin_ia32_cvtneps2bf16_128_mask(
413      (__v4sf)__V, (__v8bf)_mm_undefined_si128(), (__mmask8)-1);
414  return (__bf16)__R[0];
415}
416
417/// Convert Packed BF16 Data to Packed float Data.
418///
419/// \headerfile <x86intrin.h>
420///
421/// \param __A
422///    A 128-bit vector of [4 x bfloat].
423/// \returns A 128-bit vector of [4 x float] come from conversion of __A
424static __inline__ __m128 __DEFAULT_FN_ATTRS128 _mm_cvtpbh_ps(__m128bh __A) {
425  return _mm_castsi128_ps(
426      (__m128i)_mm_slli_epi32((__m128i)_mm_cvtepi16_epi32((__m128i)__A), 16));
427}
428
429/// Convert Packed BF16 Data to Packed float Data.
430///
431/// \headerfile <x86intrin.h>
432///
433/// \param __A
434///    A 128-bit vector of [8 x bfloat].
435/// \returns A 256-bit vector of [8 x float] come from conversion of __A
436static __inline__ __m256 __DEFAULT_FN_ATTRS256 _mm256_cvtpbh_ps(__m128bh __A) {
437  return _mm256_castsi256_ps((__m256i)_mm256_slli_epi32(
438      (__m256i)_mm256_cvtepi16_epi32((__m128i)__A), 16));
439}
440
441/// Convert Packed BF16 Data to Packed float Data using zeroing mask.
442///
443/// \headerfile <x86intrin.h>
444///
445/// \param __U
446///    A 4-bit mask. Elements are zeroed out when the corresponding mask
447///    bit is not set.
448/// \param __A
449///    A 128-bit vector of [4 x bfloat].
450/// \returns A 128-bit vector of [4 x float] come from conversion of __A
451static __inline__ __m128 __DEFAULT_FN_ATTRS128
452_mm_maskz_cvtpbh_ps(__mmask8 __U, __m128bh __A) {
453  return _mm_castsi128_ps((__m128i)_mm_slli_epi32(
454      (__m128i)_mm_maskz_cvtepi16_epi32((__mmask8)__U, (__m128i)__A), 16));
455}
456
457/// Convert Packed BF16 Data to Packed float Data using zeroing mask.
458///
459/// \headerfile <x86intrin.h>
460///
461/// \param __U
462///    A 8-bit mask. Elements are zeroed out when the corresponding mask
463///    bit is not set.
464/// \param __A
465///    A 128-bit vector of [8 x bfloat].
466/// \returns A 256-bit vector of [8 x float] come from conversion of __A
467static __inline__ __m256 __DEFAULT_FN_ATTRS256
468_mm256_maskz_cvtpbh_ps(__mmask8 __U, __m128bh __A) {
469  return _mm256_castsi256_ps((__m256i)_mm256_slli_epi32(
470      (__m256i)_mm256_maskz_cvtepi16_epi32((__mmask8)__U, (__m128i)__A), 16));
471}
472
473/// Convert Packed BF16 Data to Packed float Data using merging mask.
474///
475/// \headerfile <x86intrin.h>
476///
477/// \param __S
478///    A 128-bit vector of [4 x float]. Elements are copied from __S when
479///     the corresponding mask bit is not set.
480/// \param __U
481///    A 4-bit mask. Elements are zeroed out when the corresponding mask
482///    bit is not set.
483/// \param __A
484///    A 128-bit vector of [4 x bfloat].
485/// \returns A 128-bit vector of [4 x float] come from conversion of __A
486static __inline__ __m128 __DEFAULT_FN_ATTRS128
487_mm_mask_cvtpbh_ps(__m128 __S, __mmask8 __U, __m128bh __A) {
488  return _mm_castsi128_ps((__m128i)_mm_mask_slli_epi32(
489      (__m128i)__S, (__mmask8)__U, (__m128i)_mm_cvtepi16_epi32((__m128i)__A),
490      16));
491}
492
493/// Convert Packed BF16 Data to Packed float Data using merging mask.
494///
495/// \headerfile <x86intrin.h>
496///
497/// \param __S
498///    A 256-bit vector of [8 x float]. Elements are copied from __S when
499///     the corresponding mask bit is not set.
500/// \param __U
501///    A 8-bit mask. Elements are zeroed out when the corresponding mask
502///    bit is not set.
503/// \param __A
504///    A 128-bit vector of [8 x bfloat].
505/// \returns A 256-bit vector of [8 x float] come from conversion of __A
506static __inline__ __m256 __DEFAULT_FN_ATTRS256
507_mm256_mask_cvtpbh_ps(__m256 __S, __mmask8 __U, __m128bh __A) {
508  return _mm256_castsi256_ps((__m256i)_mm256_mask_slli_epi32(
509      (__m256i)__S, (__mmask8)__U, (__m256i)_mm256_cvtepi16_epi32((__m128i)__A),
510      16));
511}
512
// Remove the helper attribute macros defined at the top of this header so
// they are no longer defined after this point.
#undef __DEFAULT_FN_ATTRS128
#undef __DEFAULT_FN_ATTRS256
515
516#endif
517#endif