#pragma once #include #include namespace at::vec { inline namespace CPU_CAPABILITY { /** * The `VecMask` class provides a convenient interface for working with * vectorized masks in SIMD operations. It encapsulates a `Vectorized` * mask that can be directly usable in masked vectorized operations. It provides * various methods for manipulating and accessing the mask elements: * 1. `from` and `to`: Conversion between a vector of boolean values and a * vectorized mask. * 2. `cast`: Casts the mask to a different base type. * 3. `all_zero`: Checks if all mask elements are zero. * 4. `is_masked`: Checks if a specific element is masked. * 5. `loadu`: Loads data from memory using the mask. * 6. `all_masked`: Checks if all mask elements are masked. * * Some helper template classes are provided to simplify the specialization of * the `VecMask` for the specific CPU arch: * 1. `VecMaskLoad`: Loads data from memory using the mask. * 2. `VecMaskTo`: Converts the mask to boolean. * 3. `VecMaskCast`: Casts the mask to a different base type. * */ template class VecMask; template < typename data_t, int data_n, typename mask_t, int mask_n, typename Enabled = void> struct VecMaskLoad { static inline VectorizedN apply( const data_t* ptr, const VecMask& vec_mask) { constexpr typename VecMask::size_type size = VecMask::size(); static_assert(VectorizedN::size() >= size); __at_align__ data_t data[size]; __at_align__ mask_t mask[size]; auto mask_ = VectorizedN(vec_mask); mask_.store(mask); for (int i = 0; i < size; i++) { data[i] = mask[i] ? ptr[i] : static_cast(0); } return VectorizedN::loadu(data, size); } }; template < typename dst_t, int dst_n, typename src_t, int src_n, typename Enabled = void> struct VecMaskTo { static inline VecMask apply( const VecMask& vec_mask) { auto zeros = VectorizedN(static_cast(0)); auto ones = VectorizedN(static_cast(1)); return VectorizedN::blendv( zeros, ones, vec_mask.template cast()); } }; template < typename dst_t, int dst_n, typename src_t, int src_n, typename Enabled = void> struct VecMaskCast { static inline VecMask apply( const VecMask& vec_mask) { return VecMask::from(VectorizedN(vec_mask)); } }; template struct VecMaskCast { static inline VecMask apply(const VecMask& vec_mask) { return vec_mask; } }; template struct VecMaskCheck { static inline bool all_zero(const VectorizedN& vec_mask) { __at_align__ T mask[VectorizedN::size()]; vec_mask.store(mask); return std::all_of(mask, mask + VectorizedN::size(), [](T m) { return m == static_cast(0); }); } static inline bool all_masked(const VectorizedN& vec_mask) { __at_align__ T mask[VectorizedN::size()]; vec_mask.store(mask); return std::all_of(mask, mask + VectorizedN::size(), [](T m) { return m != static_cast(0); }); } static inline bool is_masked(const VectorizedN& vec_mask, int i) { __at_align__ T mask[VectorizedN::size()]; vec_mask.store(mask); return mask[i] != static_cast(0); } }; template class VecMask { public: using size_type = int; static constexpr size_type size() { return VectorizedN::size(); } private: VectorizedN mask_; public: VecMask() : mask_(static_cast(0)) {} VecMask(const VectorizedN& mask) : mask_(mask) {} template = 0> VecMask(const Vectorized& mask) : mask_(mask) {} template static VecMask from(const VectorizedN& b_vec) { __at_align__ U b_buf[size()]; if constexpr (size() >= VectorizedN::size()) { b_vec.store(b_buf); for (int i = VectorizedN::size(); i < size(); i++) { b_buf[i] = static_cast(0); } } else { b_vec.store(b_buf, size()); } return from(b_buf); } template static VecMask from(U b) { using int_t = int_same_size_t; T mask = b ? c10::bit_cast((int_t)(~(int_t)0)) : (T)0; return VectorizedN(mask); } template static VecMask from(U* b) { using int_t = int_same_size_t; __at_align__ T mask[size()]; #ifndef __msvc_cl__ #pragma unroll #endif for (int i = 0; i < size(); i++) { *(int_t*)(mask + i) = b[i] ? ~(int_t)0 : (int_t)0; } return VectorizedN(VectorizedN::loadu(mask)); } static VecMask blendv( const VecMask& c, const VecMask& b, const VecMask& a) { VectorizedN result = VectorizedN::blendv( VectorizedN(c), VectorizedN(b), VectorizedN(a)); return result; } static VecMask set( const VecMask& a, const VecMask& b, int64_t count = size()) { VectorizedN result = VectorizedN::set( VectorizedN(a), VectorizedN(b), count); return result; } void store(bool* b, int count = size()) { constexpr int L = (VectorizedN::size() + Vectorized::size() - 1) / Vectorized::size(); auto res = this->to(); res.store(b, count); return; } template = 2, int> = 0> inline VectorizedN to() const { return VecMaskTo::apply(*this); } template = 0> inline Vectorized to() const { return VecMaskTo::apply(*this); } template inline VecMask cast() const { return VecMaskCast::apply(*this); } inline bool all_zero() const { return VecMaskCheck::all_zero(mask_); } inline bool all_masked() const { return VecMaskCheck::all_masked(mask_); } inline bool is_masked(int i) const { return VecMaskCheck::is_masked(mask_, i); } inline operator VectorizedN() const { return mask_; } template = 0> inline operator Vectorized() const { return mask_[0]; } inline Vectorized operator[](int i) const { return mask_[i]; } template < typename U, int L, std::enable_if_t= 2 && VectorizedN::size() >= size(), int> = 0> VectorizedN loadu(const U* ptr) const { return VecMaskLoad::apply(ptr, *this); } template < typename U, int L, std::enable_if_t::size() >= size(), int> = 0> Vectorized loadu(const U* ptr) const { return VecMaskLoad::apply(ptr, *this); } }; #define VEC_MASK_DEFINE_UNARY_OP_GLOBAL(op) \ template \ inline VecMask op(const VecMask& a) { \ return op(VectorizedN(a)); \ } #define VEC_MASK_DEFINE_BINARY_OP_GLOBAL(op) \ template < \ typename T, \ int N, \ typename V, \ int M, \ std::enable_if_t::size() == VecMask::size(), int> = \ 0> \ inline VecMask op(const VecMask& a, const VecMask& b) { \ return op( \ VectorizedN(a), VectorizedN(b.template cast())); \ } #define VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(op, EXPR) \ template < \ typename T, \ int N, \ typename V, \ int M, \ std::enable_if_t::size() == VecMask::size(), int> = \ 0> \ inline VecMask op(const VecMask& a, const VecMask& b) { \ return EXPR; \ } VEC_MASK_DEFINE_UNARY_OP_GLOBAL(operator~) VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator&) VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator|) VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator^) VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator*) VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator>, a & ~b) VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator<, ~a& b) VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator==, ~(a ^ b)) VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator>=, (a == b) | (a > b)) VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator<=, (a == b) | (a < b)) VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator!=, (a ^ b)) #undef VEC_MASK_DEFINE_UNARY_OP_GLOBAL #undef VEC_MASK_DEFINE_BINARY_OP_GLOBAL #undef VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL } // namespace CPU_CAPABILITY } // namespace at::vec