master
  1/*===------------- amxtf32intrin.h - AMX_TF32 intrinsics -*- C++ -*---------===
  2 *
  3 * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4 * See https://llvm.org/LICENSE.txt for license information.
  5 * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6 *
  7 *===------------------------------------------------------------------------===
  8 */
  9
 10#ifndef __IMMINTRIN_H
 11#error "Never use <amxtf32intrin.h> directly; include <immintrin.h> instead."
 12#endif // __IMMINTRIN_H
 13
 14#ifndef __AMX_TF32INTRIN_H
 15#define __AMX_TF32INTRIN_H
 16#ifdef __x86_64__
 17
// Attribute set applied to the inline functions defined in this header:
// always inline them, suppress debug info for them, and compile them with the
// "amx-tf32" target feature so the TMMULTF32PS builtins are usable inside.
#define __DEFAULT_FN_ATTRS_TF32                                                \
  __attribute__((__always_inline__, __nodebug__, __target__("amx-tf32")))
 20
 21/// Do Matrix Multiplication of \a a and \a b, and then do Matrix Plus
 22/// with \a srcdst.
 23/// All the calculation is base on float32 but with the lower 13-bit set to 0.
 24///
 25/// \headerfile <immintrin.h>
 26///
 27/// \code
 28/// void _tile_mmultf32ps(constexpr int srcdst, constexpr int a, \
 29///                       constexpr int b);
 30/// \endcode
 31///
 32/// This intrinsic corresponds to the <c> TMMULTF32PS </c> instruction.
 33///
 34/// \param srcdst
 35/// 	The destination tile. Max size is 1024 Bytes.
 36/// \param a
 37/// 	The 1st source tile. Max size is 1024 Bytes.
 38/// \param b
 39/// 	The 2nd source tile. Max size is 1024 Bytes.
 40///
 41/// \code{.operation}
 42/// DEFINE zero_lower_mantissa_bits_fp32(x[31:0]) {
 43///	dword[12:0] := 0
 44///	dword[31:13] := x[31:13]
 45///	return dword
 46/// }
 47///
 48/// DEFINE silence_snan_fp32(x[31:0]) {
 49/// 	IF (x.exponent == 255 and x.fraction != 0 and x.fraction[22] == 0)
 50/// 		x.fraction[22] := 1
 51/// 	return x
 52/// }
 53///
 54/// elements_a := a.colsb / 4
 55/// elements_dest := srcdst.colsb / 4
 56///
 57/// FOR m = 0 TO (srcdst.rows-1)
 58/// 	tmp[511:0] := 0
 59/// 	FOR k = 0 TO (elements_a-1)
 60/// 		FOR n = 0 TO (elements_dest-1)
 61/// 			af := silence_snan_fp32(a.row[m].fp32[k])
 62/// 			bf := silence_snan_fp32(b.row[k].fp32[n])
 63/// 			tmp.fp32[n] += zero_lower_mantissa_bits_fp32(af)
 64/// 					* zero_lower_mantissa_bits_fp32(bf)
 65/// 		ENDFOR
 66/// 	ENDFOR
 67///
 68/// 	FOR n = 0 TO (elements_dest-1)
 69/// 		tmp.fp32[n] += srcdst.row[m].fp32[n]
 70/// 	ENDFOR
 71///	write_row_and_zero(srcdst, m, tmp, srcdst.colsb)
 72///
 73/// ENDFOR
 74///
 75/// zero_upper_rows(srcdst, srcdst.rows)
 76/// zero_tileconfig_start()
 77/// \endcode
 78#define _tile_mmultf32ps(srcdst, a, b)                                         \
 79  __builtin_ia32_tmmultf32ps((srcdst), (a), (b))
 80
 81static __inline__ _tile1024i __DEFAULT_FN_ATTRS_TF32
 82_tile_mmultf32ps_internal(unsigned short m, unsigned short n, unsigned short k,
 83                          _tile1024i dst, _tile1024i src1, _tile1024i src2) {
 84  return __builtin_ia32_tmmultf32ps_internal(m, n, k, dst, src1, src2);
 85}
 86
 87/// Do Matrix Multiplication of src0 and src1, and then do Matrix Plus with dst.
 88/// All the calculation is base on float32 but with the lower 13-bit set to 0.
 89///
 90/// \headerfile <immintrin.h>
 91///
 92/// This intrinsic corresponds to the <c> TMMULTF32PS </c> instruction.
 93///
 94/// \param dst
 95///    The destination tile. Max size is 1024 Bytes.
 96/// \param src0
 97///    The 1st source tile. Max size is 1024 Bytes.
 98/// \param src1
 99///    The 2nd source tile. Max size is 1024 Bytes.
100__DEFAULT_FN_ATTRS_TF32
101static void __tile_mmultf32ps(__tile1024i *dst, __tile1024i src0,
102                              __tile1024i src1) {
103  dst->tile = _tile_mmultf32ps_internal(src0.row, src1.col, src0.col, dst->tile,
104                                        src0.tile, src1.tile);
105}
106
107#endif // __x86_64__
108#endif // __AMX_TF32INTRIN_H