#pragma once // DO NOT DEFINE STATIC DATA IN THIS HEADER! // See Note [Do not compile initializers with AVX] #include // clang-format off #include #include #include #include #include #include #include #include #include #include #include // clang-format on #include #include #include #include #include namespace at { namespace vec { // See Note [CPU_CAPABILITY namespace] inline namespace CPU_CAPABILITY { inline std::ostream& operator<<(std::ostream& stream, const c10::qint32& val) { stream << val.val_; return stream; } inline std::ostream& operator<<(std::ostream& stream, const c10::qint8& val) { stream << static_cast(val.val_); return stream; } inline std::ostream& operator<<(std::ostream& stream, const c10::quint8& val) { stream << static_cast(val.val_); return stream; } template std::ostream& operator<<(std::ostream& stream, const Vectorized& vec) { T buf[Vectorized::size()]; vec.store(buf); stream << "vec["; for (int i = 0; i != Vectorized::size(); i++) { if (i != 0) { stream << ", "; } stream << buf[i]; } stream << "]"; return stream; } #if defined(CPU_CAPABILITY_AVX512) // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX512) // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template <> inline Vectorized cast(const Vectorized& src) { return _mm512_castpd_ps(src); } template <> inline Vectorized cast(const Vectorized& src) { return _mm512_castps_pd(src); } template <> inline Vectorized cast(const Vectorized& src) { return _mm512_castsi512_ps(src); } template <> inline Vectorized cast( const Vectorized& src) { return _mm512_castsi512_pd(src); } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #ifndef _MSC_VER // MSVC is not working well on complex function overload. template std::enable_if_t< scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized< double>> inline gather(const double* base_addr, const Vectorized& vindex) { return _mm512_i64gather_pd(vindex, base_addr, scale); } template std::enable_if_t< scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized< float>> inline gather(const float* base_addr, const Vectorized& vindex) { return _mm512_i32gather_ps(vindex, base_addr, scale); } #endif // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #ifndef _MSC_VER // MSVC is not working well on complex function overload. template std:: enable_if_t> inline mask_gather( const Vectorized& src, const double* base_addr, const Vectorized& vindex, Vectorized& mask) { auto all_ones = _mm512_castsi512_pd(_mm512_set1_epi64(0xFFFFFFFFFFFFFFFF)); auto mask_ = _mm512_cmp_pd_mask(all_ones, mask.values, _CMP_EQ_OQ); return _mm512_mask_i64gather_pd(src, mask_, vindex, base_addr, scale); } template std:: enable_if_t> inline mask_gather( const Vectorized& src, const float* base_addr, const Vectorized& vindex, Vectorized& mask) { auto all_ones = _mm512_castsi512_ps(_mm512_set1_epi32(0xFFFFFFFF)); auto mask_ = _mm512_cmp_ps_mask(all_ones, mask.values, _CMP_EQ_OQ); return _mm512_mask_i32gather_ps(src, mask_, vindex, base_addr, scale); } #endif // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template <> Vectorized inline convert_to_int_of_same_size( const Vectorized& src) { return _mm512_cvtpd_epi64(src); } template <> Vectorized inline convert_to_int_of_same_size( const Vectorized& src) { return _mm512_cvttps_epi32(src); } template <> Vectorized inline convert_to_fp_of_same_size( const Vectorized& src) { return _mm512_cvtepi64_pd(src); } template <> Vectorized inline convert_to_fp_of_same_size( const Vectorized& src) { return _mm512_cvtepi32_ps(src); } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template <> std::pair, Vectorized> inline interleave2( const Vectorized& a, const Vectorized& b) { // inputs: // a = {a0, a1, a3, a3, a4, a5, a6, a7} // b = {b0, b1, b2, b3, b4, b5, b6, b7} // group cols crossing lanes: // return {a0, b0, a1, b1, a2, b2, a3, b3} // {a4, b4, a5, b5, a6, b6, a7, b7} __m512i idx1 = _mm512_set_epi64(11, 3, 10, 2, 9, 1, 8, 0); __m512i idx2 = _mm512_set_epi64(15, 7, 14, 6, 13, 5, 12, 4); return std::make_pair( _mm512_mask_permutex2var_pd(a, 0xff, idx1, b), _mm512_mask_permutex2var_pd(a, 0xff, idx2, b)); } template <> std::pair, Vectorized> inline interleave2( const Vectorized& a, const Vectorized& b) { // inputs: // a = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, // a15} b = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, // b14, b15} // // return: // {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7} // {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, // b15} __m512i idx1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0); __m512i idx2 = _mm512_set_epi32( 31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8); return std::make_pair( _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b), _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b)); } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template <> std::pair, Vectorized> inline deinterleave2( const Vectorized& a, const Vectorized& b) { // inputs: // a = {a0, b0, a1, b1, a2, b2, a3, b3} // b = {a4, b4, a5, b5, a6, b6, a7, b7} // output: // return {a0, a1, a2, a3, a4, a5, a6, a7} // {b0, b1, b2, b3, b4, b5, b6, b7} // The members of indices have been written in binary format for better // understandability __m512i idx1 = _mm512_set_epi64(14, 12, 10, 8, 6, 4, 2, 0); __m512i idx2 = _mm512_set_epi64(15, 13, 11, 9, 7, 5, 3, 1); return std::make_pair( _mm512_mask_permutex2var_pd(a, 0xff, idx1, b), _mm512_mask_permutex2var_pd(a, 0xff, idx2, b)); } template <> std::pair, Vectorized> inline deinterleave2( const Vectorized& a, const Vectorized& b) { // inputs: // a = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7} // b = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, // a15, b15} // output: // return {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, // a15} // {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, // b15} __m512i idx1 = _mm512_set_epi32( 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0); __m512i idx2 = _mm512_set_epi32( 31, 29, 27, 25, 23, 21, 19, 17, 15, 13, 11, 9, 7, 5, 3, 1); return std::make_pair( _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b), _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b)); } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FLIP ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template <> inline Vectorized flip(const Vectorized& v) { const __m512i mask = _mm512_set_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); return _mm512_permutexvar_ps(mask, v); } template <> inline Vectorized flip(const Vectorized& v) { const __m512i mask = _mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7); return _mm512_permutexvar_pd(mask, v); } template <> inline Vectorized flip(const Vectorized& v) { const __m512i mask = _mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7); return _mm512_permutexvar_epi64(mask, v); } template <> inline Vectorized flip(const Vectorized& v) { const __m512i mask = _mm512_set_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); return _mm512_permutexvar_epi32(mask, v); } template <> inline Vectorized flip(const Vectorized& v) { const __m512i mask = _mm512_set_epi16( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31); return _mm512_permutexvar_epi16(mask, v); } inline __m512i flip8(const __m512i& v) { const __m512i mask1 = _mm512_set_epi8( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); const __m512i mask2 = _mm512_set_epi64(1, 0, 3, 2, 5, 4, 7, 6); auto reversed_vec = _mm512_shuffle_epi8(v, mask1); return _mm512_permutexvar_epi64(mask2, reversed_vec); } template <> inline Vectorized flip(const Vectorized& v) { return flip8(v); } template <> inline Vectorized flip(const Vectorized& v) { return flip8(v); } inline Vectorized operator&&( const Vectorized& self, const Vectorized& other) { const __m512i* self_ = reinterpret_cast(self.as_bytes()); const __m512i* other_ = reinterpret_cast(other.as_bytes()); __m512i out = _mm512_and_si512(*self_, *other_); Vectorized ret; // We do not have a constructer that takes __m512i, so we need to memcpy std::memcpy(ret, &out, ret.size() * sizeof(bool)); return ret; } #endif // defined(CPU_CAPABILITY_AVX512) } // namespace CPU_CAPABILITY } // namespace vec } // namespace at