#pragma once #if defined(__GNUC__) && __GNUC__ == 10 && __GNUC_MINOR__ <= 2 && \ defined(__ARM_FEATURE_SVE) // Workaround for https: //gcc.gnu.org/bugzilla/show_bug.cgi?id=117161 #pragma GCC optimize("no-tree-vectorize") #endif // DO NOT DEFINE STATIC DATA IN THIS HEADER! // See Note [Do not compile initializers with AVX] // // Note [Do not compile initializers with AVX] // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // If you define a static initializer in this file, the initialization will use // AVX instructions because these object files are compiled with AVX enabled. // We need to avoid non-trivial global data in these architecture specific files // because there's no way to guard the global initializers with CPU capability // detection. // // See https://github.com/pytorch/pytorch/issues/37577 for an instance // of this bug in the past. #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #if defined(__GNUC__) #define __FORCE_INLINE __attribute__((always_inline)) inline #elif defined(_MSC_VER) #define __FORCE_INLINE __forceinline #endif #if defined(_MSC_FULL_VER) /* https://learn.microsoft.com/en-us/cpp/overview/compiler-versions?view=msvc-170 Use _MSC_FULL_VER to identify current compiler is msvc, Windows llvm will not have this definition. */ #define __msvc_cl__ #endif // These macros helped us unify vec_base.h #ifdef CPU_CAPABILITY_AVX512 #if defined(__GNUC__) #define __at_align__ __attribute__((aligned(64))) #elif defined(_WIN32) #define __at_align__ __declspec(align(64)) #else #define __at_align__ #endif #define VECTOR_WIDTH 64 #define int_vector __m512i #elif defined(__aarch64__) && \ !defined(CPU_CAPABILITY_SVE) // CPU_CAPABILITY_AVX512 // SVE code expects 256-vectors; leave that set for SVE? #if defined(__GNUC__) #define __at_align__ __attribute__((aligned(16))) #elif defined(_WIN32) #define __at_align__ __declspec(align(16)) #else #define __at_align__ #endif #define VECTOR_WIDTH 16 #else // CPU_CAPABILITY_AVX512 #if defined(__GNUC__) #define __at_align__ __attribute__((aligned(32))) #elif defined(_WIN32) #define __at_align__ __declspec(align(32)) #else #define __at_align__ #endif #define VECTOR_WIDTH 32 #define int_vector __m256i #endif // CPU_CAPABILITY_AVX512 namespace at::vec { // See Note [CPU_CAPABILITY namespace] inline namespace CPU_CAPABILITY { // at::Half and at::BFloat16 should be treated as floating point template struct is_floating_point : std::integral_constant< bool, std::is_floating_point_v || std::is_same_v || std::is_same_v> {}; template constexpr bool is_floating_point_v = is_floating_point::value; template struct is_reduced_floating_point : std::integral_constant< bool, std::is_same_v || std::is_same_v> {}; template constexpr bool is_reduced_floating_point_v = is_reduced_floating_point::value; template struct is_8bit_integer : std::integral_constant< bool, std::is_same_v || std::is_same_v> { }; template constexpr bool is_8bit_integer_v = is_8bit_integer::value; template struct int_of_size; #define DEFINE_INT_OF_SIZE(int_t) \ template <> \ struct int_of_size { \ using type = int_t; \ } DEFINE_INT_OF_SIZE(int64_t); DEFINE_INT_OF_SIZE(int32_t); DEFINE_INT_OF_SIZE(int16_t); DEFINE_INT_OF_SIZE(int8_t); #undef DEFINE_INT_OF_SIZE template using int_same_size_t = typename int_of_size::type; /** * Detect at compile time whether Vectorized has an explicit * specialization for T. (You are required to specialize this type * whenever you specialize Vectorized). Useful for generic algorithms * to decide whether to rely on a specialization being fast. For * example, they might choose to handle reduced-precision floating * point types directly if they're supported, or convert through float * if not. */ #if defined(__s390x__) template #else template #endif struct is_vec_specialized_for : std::bool_constant { }; template constexpr bool is_vec_specialized_for_v = is_vec_specialized_for::value; // NOTE: If you specialize Vectorized on a type, you must define all // operations! You must also specialize is_vec_specialized_for for // that type. // emulates Vectorized types #if defined(__s390x__) template #else template #endif struct Vectorized { private: __at_align__ T values[VECTOR_WIDTH / sizeof(T)]; public: using value_type = T; using size_type = int; static constexpr size_type kSize = VECTOR_WIDTH / sizeof(T); static constexpr size_type size() { return kSize; } Vectorized() : values{static_cast(0)} {} Vectorized(T val) { for (int i = 0; i != size(); i++) { values[i] = val; } } template < typename... Args, typename = std::enable_if_t<(sizeof...(Args) == size())>> Vectorized(Args... vals) : values{vals...} {} Vectorized(const T (&arr)[kSize]) { std::memcpy(values, arr, sizeof(values)); } // This also implies const T& operator[](int idx) const inline operator const T*() const { return values; } // This also implies T& operator[](int idx) inline operator T*() { return values; } // Return the values as char* for type punning auto as_bytes() const -> const char* { return reinterpret_cast(values); } template static Vectorized blend(const Vectorized& a, const Vectorized& b) { int64_t mask = mask_; Vectorized vector; for (const auto i : c10::irange(size())) { if (mask & 0x01) { vector[i] = b[i]; } else { vector[i] = a[i]; } mask = mask >> 1; } return vector; } // Workaround for https: //gcc.gnu.org/bugzilla/show_bug.cgi?id=117001 #if __GNUC__ <= 12 && !defined(__clang__) && defined(__ARM_FEATURE_SVE) static Vectorized __attribute__((optimize("-fno-tree-loop-vectorize"))) blendv( const Vectorized& a, #else static Vectorized blendv( const Vectorized& a, #endif const Vectorized& b, const Vectorized& mask) { Vectorized vector; int_same_size_t buffer[size()]; mask.store(buffer); for (const auto i : c10::irange(size())) { if (buffer[i] & 0x01) { vector[i] = b[i]; } else { vector[i] = a[i]; } } return vector; } template // step sometimes requires a higher precision type // (e.g., T=int, step_t=double) static Vectorized arange( T base = static_cast(0), step_t step = static_cast(1)) { Vectorized vector; for (const auto i : c10::irange(size())) { vector.values[i] = base + i * step; } return vector; } static Vectorized set( const Vectorized& a, const Vectorized& b, int64_t count = size()) { Vectorized vector; for (const auto i : c10::irange(size())) { if (i < count) { vector[i] = b[i]; } else { vector[i] = a[i]; } } return vector; } static Vectorized loadu(const void* ptr) { Vectorized vector; std::memcpy(vector.values, ptr, VECTOR_WIDTH); return vector; } static Vectorized loadu(const void* ptr, int64_t count) { Vectorized vector; std::memcpy(vector.values, ptr, count * sizeof(T)); return vector; } static Vectorized loadu_one_fourth(const void* ptr) { static_assert( std::is_same_v || std::is_same_v, "For byte types only"); return Vectorized::loadu(ptr, 8); } void store(void* ptr, int count = size()) const { std::memcpy(ptr, values, count * sizeof(T)); } int zero_mask() const { // returns an integer mask where all zero elements are translated to 1-bit // and others are translated to 0-bit int mask = 0; for (int i = 0; i < size(); ++i) { if (values[i] == static_cast(0)) { mask |= (1 << i); } } return mask; } Vectorized isnan() const { Vectorized vector; for (int64_t i = 0; i != size(); i++) { if (_isnan(values[i])) { std::memset(static_cast(vector.values + i), 0xFF, sizeof(T)); } else { std::memset(static_cast(vector.values + i), 0, sizeof(T)); } } return vector; } bool has_inf_nan() const { for (int64_t i = 0; i != size(); i++) { if (_isnan(values[i]) || _isinf(values[i])) { return true; } } return false; } // MSVC versions between 14.36 and 14.42 has a loop unrolling bug on Windows // Arm64 // See // https://developercommunity.visualstudio.com/t/MSVC-loop-unrolling-problem-194033813-/10720692 #if defined(_WIN32) && defined(__aarch64__) && \ ((_MSVC_VER >= 1936) && (_MSVC_VER <= 1942)) Vectorized map(T (*const f)(T)) const { Vectorized ret; for (int64_t i = 0; i < size(); i++) { ret[i] = f(values[i]); if (++i < size()) ret[i] = f(values[i]); } return ret; } T reduce(T (*const f)(T)) const { T ret = 0; for (int64_t i = 0; i < size(); i++) { ret = f(ret, values[i]); if (++i < size()) ret = f(ret, values[i]); } return ret; } #else Vectorized map(T (*const f)(T)) const { Vectorized ret; for (int64_t i = 0; i != size(); i++) { ret[i] = f(values[i]); } return ret; } T reduce(T (*const f)(T)) const { T ret = 0; for (int64_t i = 0; i != size(); i++) { ret = f(ret, values[i]); } return ret; } #endif Vectorized map(T (*const f)(const T&)) const { Vectorized ret; for (int64_t i = 0; i != size(); i++) { ret[i] = f(values[i]); } return ret; } T reduce(T (*const f)(const T&)) const { T ret = 0; for (int64_t i = 0; i != size(); i++) { ret = f(ret, values[i]); } return ret; } template < typename other_t_abs = T, typename std::enable_if_t< !is_floating_point_v && !c10::is_complex::value, int> = 0> Vectorized abs() const { // other_t_abs is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same_v, "other_t_abs must be T"); return map([](T x) -> T { return x < static_cast(0) ? -x : x; }); } template < typename float_t_abs = T, typename std::enable_if_t, int> = 0> Vectorized abs() const { // float_t_abs is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same_v, "float_t_abs must be T"); // Specifically deal with floating-point because the generic code above // won't handle -0.0 (which should result in 0.0) properly. return map([](T x) -> T { return std::abs(x); }); } template < typename complex_t_abs = T, typename std::enable_if_t::value, int> = 0> Vectorized abs() const { // complex_t_abs is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same_v, "complex_t_abs must be T"); // Specifically map() does not perform the type conversion needed by abs. return map([](T x) { return static_cast(std::abs(x)); }); } template < typename other_t_sgn = T, typename std::enable_if_t::value, int> = 0> Vectorized sgn() const { return map(at::native::sgn_impl); } template < typename other_t_angle = T, typename std::enable_if_t::value, int> = 0> Vectorized angle() const { // other_t_angle is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same_v, "other_t_angle must be T"); return map(at::native::angle_impl); // compiler is unable to resolve the // overload without } template < typename complex_t_angle = T, typename std::enable_if_t::value, int> = 0> Vectorized angle() const { // complex_t_angle is for SFINAE and clarity. Make sure it is not changed. static_assert( std::is_same_v, "complex_t_angle must be T"); return map([](T x) { return static_cast(std::arg(x)); }); } template < typename other_t_real = T, typename std::enable_if_t::value, int> = 0> Vectorized real() const { // other_t_real is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same_v, "other_t_real must be T"); return *this; } template < typename complex_t_real = T, typename std::enable_if_t::value, int> = 0> Vectorized real() const { // complex_t_real is for SFINAE and clarity. Make sure it is not changed. static_assert( std::is_same_v, "complex_t_real must be T"); return map([](T x) { return static_cast(x.real()); }); } template < typename other_t_imag = T, typename std::enable_if_t::value, int> = 0> Vectorized imag() const { // other_t_imag is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same_v, "other_t_imag must be T"); return Vectorized(0); } template < typename complex_t_imag = T, typename std::enable_if_t::value, int> = 0> Vectorized imag() const { // complex_t_imag is for SFINAE and clarity. Make sure it is not changed. static_assert( std::is_same_v, "complex_t_imag must be T"); return map([](T x) { return static_cast(x.imag()); }); } template < typename other_t_conj = T, typename std::enable_if_t::value, int> = 0> Vectorized conj() const { // other_t_conj is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same_v, "other_t_conj must be T"); return *this; } template < typename complex_t_conj = T, typename std::enable_if_t::value, int> = 0> Vectorized conj() const { // complex_t_conj is for SFINAE and clarity. Make sure it is not changed. static_assert( std::is_same_v, "complex_t_conj must be T"); return map([](T x) { return static_cast(std::conj(x)); }); } Vectorized acos() const { return map(std::acos); } Vectorized acosh() const { return map(std::acosh); } Vectorized asin() const { return map(std::asin); } Vectorized asinh() const { return map(std::asinh); } Vectorized atan() const { return map(std::atan); } Vectorized atanh() const { return map(std::atanh); } Vectorized atan2(const Vectorized& exp) const { Vectorized ret; for (const auto i : c10::irange(size())) { ret[i] = std::atan2(values[i], exp[i]); } return ret; } template < typename U = T, typename std::enable_if_t, int> = 0> Vectorized copysign(const Vectorized& sign) const { Vectorized ret; for (size_type i = 0; i < size(); i++) { ret[i] = c10::copysign(values[i], sign[i]); } return ret; } Vectorized erf() const { return map(std::erf); } Vectorized erfc() const { return map(std::erfc); } Vectorized erfinv() const { return map(calc_erfinv); } Vectorized exp() const { return map(std::exp); } Vectorized exp2() const { return map(exp2_impl); } Vectorized expm1() const { return map(std::expm1); } Vectorized exp_u20() const { return map(std::exp); } Vectorized fexp_u20() const { return map(std::exp); } Vectorized frac() const { return *this - this->trunc(); } template < typename U = T, typename std::enable_if_t, int> = 0> Vectorized fmod(const Vectorized& q) const { // U is for SFINAE purposes only. Make sure it is not changed. static_assert(std::is_same_v, "U must be T"); Vectorized ret; for (const auto i : c10::irange(size())) { ret[i] = std::fmod(values[i], q[i]); } return ret; } Vectorized log() const { return map(std::log); } Vectorized log10() const { return map(std::log10); } Vectorized log1p() const { return map(std::log1p); } template < typename other_t_log2 = T, typename std::enable_if_t::value, int> = 0> Vectorized log2() const { // other_t_log2 is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same_v, "other_t_log2 must be T"); return map(std::log2); } template < typename complex_t_log2 = T, typename std::enable_if_t::value, int> = 0> Vectorized log2() const { // complex_t_log2 is for SFINAE and clarity. Make sure it is not changed. static_assert( std::is_same_v, "complex_t_log2 must be T"); const T log_2 = T(std::log(2.0)); return Vectorized(map(std::log)) / Vectorized(log_2); } Vectorized ceil() const { return map(at::native::ceil_impl); } Vectorized cos() const { return map(std::cos); } Vectorized cosh() const { return map(std::cosh); } Vectorized floor() const { return map(at::native::floor_impl); } Vectorized hypot(const Vectorized& b) const { Vectorized ret; for (const auto i : c10::irange(size())) { ret[i] = std::hypot(values[i], b[i]); } return ret; } Vectorized i0() const { return map(calc_i0); } Vectorized i0e() const { return map(calc_i0e); } Vectorized digamma() const { return map(calc_digamma); } Vectorized igamma(const Vectorized& x) const { Vectorized ret; for (const auto i : c10::irange(size())) { ret[i] = calc_igamma(values[i], x[i]); } return ret; } Vectorized igammac(const Vectorized& x) const { Vectorized ret; for (const auto i : c10::irange(size())) { ret[i] = calc_igammac(values[i], x[i]); } return ret; } Vectorized neg() const { // NB: the trailing return type is needed because we need to coerce the // return value back to T in the case of unary operator- incuring a // promotion return map([](T x) -> T { return -x; }); } Vectorized nextafter(const Vectorized& b) const { Vectorized ret; for (const auto i : c10::irange(size())) { ret[i] = std::nextafter(values[i], b[i]); } return ret; } Vectorized round() const { // We do not use std::round because we would like to round midway numbers to // the nearest even integer. return map(at::native::round_impl); } Vectorized sin() const { return map(std::sin); } Vectorized sinh() const { return map(std::sinh); } Vectorized tan() const { return map(std::tan); } Vectorized tanh() const { return map(std::tanh); } Vectorized trunc() const { return map(at::native::trunc_impl); } Vectorized lgamma() const { return map(std::lgamma); } Vectorized sqrt() const { return map(std::sqrt); } Vectorized reciprocal() const { return map([](T x) { return (T)(1) / x; }); } Vectorized rsqrt() const { return map([](T x) { return (T)1 / std::sqrt(x); }); } Vectorized pow(const Vectorized& exp) const { Vectorized ret; for (const auto i : c10::irange(size())) { ret[i] = std::pow(values[i], exp[i]); } return ret; } T reduce_add() const { return reduce([](T x, T y) -> T { return x + y; }); } T reduce_max() const { return reduce(std::max); } private: template inline Vectorized binary_pred(const Vectorized& other, Op op) const { // All bits are set to 1 if the pred is true, otherwise 0. Vectorized vector; for (int64_t i = 0; i != size(); i++) { if (op(values[i], other.values[i])) { std::memset(static_cast(vector.values + i), 0xFF, sizeof(T)); } else { std::memset(static_cast(vector.values + i), 0, sizeof(T)); } } return vector; } public: Vectorized operator==(const Vectorized& other) const { return binary_pred(other, std::equal_to()); } Vectorized operator!=(const Vectorized& other) const { return binary_pred(other, std::not_equal_to()); } Vectorized operator>=(const Vectorized& other) const { return binary_pred(other, std::greater_equal()); } Vectorized operator<=(const Vectorized& other) const { return binary_pred(other, std::less_equal()); } Vectorized operator>(const Vectorized& other) const { return binary_pred(other, std::greater()); } Vectorized operator<(const Vectorized& other) const { return binary_pred(other, std::less()); } private: template inline Vectorized binary_pred_bool(const Vectorized& other, Op op) const { // 1 if the pred is true, otherwise 0. Vectorized vector; for (int i = 0; i != size(); ++i) { vector[i] = static_cast(op(values[i], other.values[i])); } return vector; } public: Vectorized eq(const Vectorized& other) const { return binary_pred_bool(other, std::equal_to()); } Vectorized ne(const Vectorized& other) const { return binary_pred_bool(other, std::not_equal_to()); } Vectorized gt(const Vectorized& other) const { return binary_pred_bool(other, std::greater()); } Vectorized ge(const Vectorized& other) const { return binary_pred_bool(other, std::greater_equal()); } Vectorized lt(const Vectorized& other) const { return binary_pred_bool(other, std::less()); } Vectorized le(const Vectorized& other) const { return binary_pred_bool(other, std::less_equal()); } }; template Vectorized inline operator-(const Vectorized& a) { return a.neg(); } // There is an implicit conversion that would make this work if // these operators weren't template functions, but they are template // functions (and can't be moved to be non-member friends defined in // the class body as suggested in // https://stackoverflow.com/questions/9787593/implicit-type-conversion-with-template/9788255#9788255 // because we have a lot of disparate specializations of // Vectorized). So, just explicitly make scalars work. #define VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(name) \ template \ Vectorized inline name(const Vectorized& a, T b) { \ return name(a, Vectorized(b)); \ } \ template \ Vectorized inline name(T a, const Vectorized& b) { \ return name(Vectorized(a), b); \ } #define VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(op) \ VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(operator op) template Vectorized inline operator+(const Vectorized& a, const Vectorized& b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] + b[i]; } return c; } VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(+) template Vectorized inline operator-(const Vectorized& a, const Vectorized& b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] - b[i]; } return c; } VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(-) template Vectorized inline operator*(const Vectorized& a, const Vectorized& b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] * b[i]; } return c; } VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(*) template Vectorized inline operator/(const Vectorized& a, const Vectorized& b) __ubsan_ignore_float_divide_by_zero__ { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] / b[i]; } return c; } VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(/) template , int> = 0> Vectorized inline operator%(const Vectorized& a, const Vectorized& b) __ubsan_ignore_float_divide_by_zero__ { return a - a / b * b; } VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(%) template Vectorized inline operator||( const Vectorized& a, const Vectorized& b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] || b[i]; } return c; } VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(||) // Implements the IEEE 754 201X `maximum` operation, which propagates NaN if // either input is a NaN. template < class T, typename std::enable_if_t::value, int> = 0> Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = (a[i] > b[i]) ? a[i] : b[i]; if (_isnan(a[i])) { // If either input is NaN, propagate a NaN. // NOTE: The case where b[i] was NaN is handled correctly by the naive // ternary operator above. c[i] = a[i]; } } return c; } template < class T, typename std::enable_if_t::value, int> = 0> Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = (std::abs(a[i]) > std::abs(b[i])) ? a[i] : b[i]; if (_isnan(a[i])) { // If either input is NaN, propagate a NaN. // NOTE: The case where b[i] was NaN is handled correctly by the naive // ternary operator above. c[i] = a[i]; } } return c; } VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(maximum) // Implements the IEEE 754 201X `minimum` operation, which propagates NaN if // either input is a NaN. template < class T, typename std::enable_if_t::value, int> = 0> Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = (a[i] < b[i]) ? a[i] : b[i]; if (_isnan(a[i])) { // If either input is NaN, propagate a NaN. // NOTE: The case where b[i] was NaN is handled correctly by the naive // ternary operator above. c[i] = a[i]; } } return c; } template < class T, typename std::enable_if_t::value, int> = 0> Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = (std::abs(a[i]) < std::abs(b[i])) ? a[i] : b[i]; if (_isnan(a[i])) { // If either input is NaN, propagate a NaN. // NOTE: The case where b[i] was NaN is handled correctly by the naive // ternary operator above. c[i] = a[i]; } } return c; } VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(minimum) template < class T, typename std::enable_if_t::value, int> = 0> Vectorized inline clamp( const Vectorized& a, const Vectorized& min_vec, const Vectorized& max_vec) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = std::min(std::max(a[i], min_vec[i]), max_vec[i]); } return c; } #define VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(name) \ template \ Vectorized inline name( \ const Vectorized& a, const Vectorized& b, T c) { \ return name(a, b, Vectorized(c)); \ } \ \ template \ Vectorized inline name( \ const Vectorized& a, T b, const Vectorized& c) { \ return name(a, Vectorized(b), c); \ } \ \ template \ Vectorized inline name(const Vectorized& a, T b, T c) { \ return name(a, Vectorized(b), Vectorized(c)); \ } \ \ template \ Vectorized inline name( \ T a, const Vectorized& b, const Vectorized& c) { \ return name(Vectorized(a), b, c); \ } \ \ template \ Vectorized inline name(T a, const Vectorized& b, T c) { \ return name(Vectorized(a), b, Vectorized(c)); \ } \ \ template \ Vectorized inline name(T a, T b, const Vectorized& c) { \ return name(Vectorized(a), Vectorized(b), c); \ } VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(clamp) template < class T, typename std::enable_if_t::value, int> = 0> Vectorized inline clamp_max( const Vectorized& a, const Vectorized& max_vec) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] > max_vec[i] ? max_vec[i] : a[i]; } return c; } VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(clamp_max) template < class T, typename std::enable_if_t::value, int> = 0> Vectorized inline clamp_min( const Vectorized& a, const Vectorized& min_vec) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] < min_vec[i] ? min_vec[i] : a[i]; } return c; } VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(clamp_min) struct Vectorizedi; #if defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512) template static inline Vectorized bitwise_binary_op( const Vectorized& a, const Vectorized& b, Op op) { int_vector buffer; #if defined(CPU_CAPABILITY_AVX2) int_vector a_buffer = _mm256_load_si256(reinterpret_cast((const T*)a)); int_vector b_buffer = _mm256_load_si256(reinterpret_cast((const T*)b)); #elif defined(CPU_CAPABILITY_AVX512) int_vector a_buffer = _mm512_load_si512(reinterpret_cast((const T*)a)); int_vector b_buffer = _mm512_load_si512(reinterpret_cast((const T*)b)); #endif buffer = op(a_buffer, b_buffer); __at_align__ T results[Vectorized::size()]; #if defined(CPU_CAPABILITY_AVX2) _mm256_store_si256(reinterpret_cast(results), buffer); #elif defined(CPU_CAPABILITY_AVX512) _mm512_store_si512(reinterpret_cast(results), buffer); #endif return Vectorized::loadu(results); } template < class T, typename std::enable_if_t< !std::is_base_of>::value, int> = 0> inline Vectorized operator&(const Vectorized& a, const Vectorized& b) { // We enclose _mm512_and_si512 or _mm256_and_si256 with lambda because it is // always_inline #if defined(CPU_CAPABILITY_AVX2) return bitwise_binary_op( a, b, [](int_vector a, int_vector b) { return _mm256_and_si256(a, b); }); #elif defined(CPU_CAPABILITY_AVX512) return bitwise_binary_op( a, b, [](int_vector a, int_vector b) { return _mm512_and_si512(a, b); }); #endif } template < class T, typename std::enable_if_t< !std::is_base_of>::value, int> = 0> inline Vectorized operator|(const Vectorized& a, const Vectorized& b) { // We enclose _mm512_or_si512 or _mm256_or_si256 with lambda because it is // always_inline #if defined(CPU_CAPABILITY_AVX2) return bitwise_binary_op( a, b, [](int_vector a, int_vector b) { return _mm256_or_si256(a, b); }); #elif defined(CPU_CAPABILITY_AVX512) return bitwise_binary_op( a, b, [](int_vector a, int_vector b) { return _mm512_or_si512(a, b); }); #endif } template < class T, typename std::enable_if_t< !std::is_base_of>::value, int> = 0> inline Vectorized operator^(const Vectorized& a, const Vectorized& b) { // We enclose _mm512_xor_si512 or _mm256_xor_si256 with lambda because it is // always_inline #if defined(CPU_CAPABILITY_AVX2) return bitwise_binary_op( a, b, [](int_vector a, int_vector b) { return _mm256_xor_si256(a, b); }); #elif defined(CPU_CAPABILITY_AVX512) return bitwise_binary_op( a, b, [](int_vector a, int_vector b) { return _mm512_xor_si512(a, b); }); #endif } #else template auto load(char const* data) -> T { T ret; std::memcpy(&ret, data, sizeof(ret)); return ret; } template static inline Vectorized bitwise_binary_op( const Vectorized& a, const Vectorized& b, Op op) { static constexpr uint32_t element_no = VECTOR_WIDTH / sizeof(intmax_t); __at_align__ intmax_t buffer[element_no]; static_assert( VECTOR_WIDTH % sizeof(intmax_t) == 0, "VECTOR_WIDTH not a multiple of sizeof(intmax_t)"); static_assert( sizeof(buffer) == sizeof(Vectorized), "sizeof(buffer) must match sizeof(Vectorized)"); // We should be using memcpy in order to respect the strict aliasing rule // see: https://github.com/pytorch/pytorch/issues/66119 // Using char* is defined in the C11 standard 6.5 Expression paragraph 7 // (http://www.open-std.org/jtc1/sc22/wg14/www/docs/n1570.pdf) const auto* a_data = a.as_bytes(); const auto* b_data = b.as_bytes(); // load each intmax_t chunk and process; increase pointers by sizeof(intmax_t) for (auto& out : buffer) { out = op(load(a_data), load(b_data)); a_data += sizeof(intmax_t); b_data += sizeof(intmax_t); } assert(a_data == a.as_bytes() + sizeof(a)); assert(b_data == b.as_bytes() + sizeof(b)); return Vectorized::loadu(buffer); } template < class T, typename std:: enable_if_t>, int> = 0> inline Vectorized operator&(const Vectorized& a, const Vectorized& b) { return bitwise_binary_op(a, b, std::bit_and()); } template < class T, typename std:: enable_if_t>, int> = 0> inline Vectorized operator|(const Vectorized& a, const Vectorized& b) { return bitwise_binary_op(a, b, std::bit_or()); } template < class T, typename std:: enable_if_t>, int> = 0> inline Vectorized operator^(const Vectorized& a, const Vectorized& b) { return bitwise_binary_op(a, b, std::bit_xor()); } #endif // defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512) VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(&) VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(|) VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(^) template < class T, typename std:: enable_if_t>, int> = 0> inline Vectorized operator~(const Vectorized& a) { using int_t = int_same_size_t; Vectorized ones(c10::bit_cast((int_t)(~(int_t)0))); // All bits are 1 return a ^ ones; } template Vectorized inline operator<<( const Vectorized& a, const Vectorized& b) { constexpr T max_shift = sizeof(T) * CHAR_BIT; Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { T shift = b[i]; if ((static_cast>(shift) < 0) || (shift >= max_shift)) { c[i] = 0; } else { c[i] = static_cast>(a[i]) << shift; } } return c; } template Vectorized inline operator>>( const Vectorized& a, const Vectorized& b) { // right shift value to retain sign bit for signed and no bits for unsigned constexpr T max_shift = sizeof(T) * CHAR_BIT - std::is_signed_v; Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { T shift = b[i]; if ((static_cast>(shift) < 0) || (shift >= max_shift)) { c[i] = a[i] >> max_shift; } else { c[i] = a[i] >> shift; } } return c; } template inline Vectorized& operator+=(Vectorized& a, const Vectorized& b) { a = a + b; return a; } template inline Vectorized& operator-=(Vectorized& a, const Vectorized& b) { a = a - b; return a; } template inline Vectorized& operator/=(Vectorized& a, const Vectorized& b) { a = a / b; return a; } template inline Vectorized& operator%=(Vectorized& a, const Vectorized& b) { a = a % b; return a; } template inline Vectorized& operator*=(Vectorized& a, const Vectorized& b) { a = a * b; return a; } template inline Vectorized& operator<<=(Vectorized& a, const Vectorized& b) { a = a << b; return a; } template inline Vectorized& operator>>=(Vectorized& a, const Vectorized& b) { a = a >> b; return a; } template inline Vectorized fmadd( const Vectorized& a, const Vectorized& b, const Vectorized& c) { return a * b + c; } VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(fmadd) template inline Vectorized fnmadd( const Vectorized& a, const Vectorized& b, const Vectorized& c) { return -(a * b) + c; } VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(fnmadd) template inline Vectorized fmsub( const Vectorized& a, const Vectorized& b, const Vectorized& c) { return a * b - c; } VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(fmsub) template inline Vectorized fnmsub( const Vectorized& a, const Vectorized& b, const Vectorized& c) { return -(a * b) - c; } VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(fnmsub) template Vectorized inline operator&&( const Vectorized& a, const Vectorized& b) { Vectorized ret; for (int i = 0; i != Vectorized::size(); i++) { ret[i] = a[i] && b[i]; } return ret; } VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(&&) template std::enable_if_t< scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized< T>> inline gather(T const* base_addr, const Vectorized>& vindex) { static constexpr int size = Vectorized::size(); int_same_size_t index_arr[size]; vindex.store(static_cast(index_arr)); T buffer[size]; for (const auto i : c10::irange(size)) { buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)]; } return Vectorized::loadu(static_cast(buffer)); } template std:: enable_if_t> inline mask_gather( const Vectorized& src, T const* base_addr, const Vectorized>& vindex, Vectorized& mask) { static constexpr int size = Vectorized::size(); T src_arr[size]; int_same_size_t mask_arr[size]; // use int type so we can logical and int_same_size_t index_arr[size]; src.store(static_cast(src_arr)); mask.store(static_cast(mask_arr)); vindex.store(static_cast(index_arr)); T buffer[size]; for (const auto i : c10::irange(size)) { if (mask_arr[i] & 0x01) { // check highest bit buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)]; } else { buffer[i] = src_arr[i]; } } mask = Vectorized(static_cast(0)); // "zero out" mask return Vectorized::loadu(static_cast(buffer)); } // Cast a given vector to another type without changing the bits representation. // So a Vectorized of 512 bits containing all ones can be cast to a // Vectorized of 512 bits containing all ones (i.e., eight negative // 1s). A Vec of 256 bits containing all ones can be cast to a // Vec of 256 bits containing all ones (i.e., four negative 1s). // There is a struct here because we don't have static_if and I can't // partially specialize a templated function. template struct CastImpl { static inline Vectorized apply(const Vectorized& src) { src_t src_arr[Vectorized::size()]; src.store(static_cast(src_arr)); return Vectorized::loadu(static_cast(src_arr)); } }; template struct CastImpl { static inline Vectorized apply(const Vectorized& src) { return src; } }; template inline Vectorized cast(const Vectorized& src) { return CastImpl::apply(src); } template > inline Vectorized convert_to_int_of_same_size( const Vectorized& src) { static_assert(sizeof(T) == sizeof(IntType)); static constexpr int size = Vectorized::size(); std::array src_arr = {}; src.store(static_cast(src_arr.data())); std::array buffer; std::transform( src_arr.cbegin(), src_arr.cend(), buffer.begin(), [](const T& x) { return static_cast(x); }); return Vectorized::loadu(static_cast(buffer.data())); } template > inline Vectorized convert_to_fp_of_same_size( const Vectorized& src) { static_assert(sizeof(T) == sizeof(IntType)); static constexpr int size = Vectorized::size(); std::array src_arr; src.store(static_cast(src_arr.data())); std::array buffer; std::transform( src_arr.cbegin(), src_arr.cend(), buffer.begin(), [](const IntType& x) { return static_cast(x); }); return Vectorized::loadu(static_cast(buffer.data())); } // clang-format off // Example inputs for AVX512: // a Vectorized = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7} // b Vectorized = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15} // returns: // Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15} // Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15} // Example inputs for AVX2: a Vectorized = {a0, b0, a1, b1, a2, b2, a3, b3} // b Vectorized = {a4, b4, a5, b5, a6, b6, a7, b7} // returns: Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7} // Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7} // clang-format on template inline std::enable_if_t< Vectorized::size() % 2 == 0, std::pair, Vectorized>> deinterleave2(const Vectorized& a, const Vectorized& b) { static constexpr int size = Vectorized::size(); static constexpr int half_size = size / 2; T a_arr[size]; T b_arr[size]; T buffer1[size]; T buffer2[size]; a.store(static_cast(a_arr)); b.store(static_cast(b_arr)); for (const auto i : c10::irange(half_size)) { buffer1[i] = a_arr[i * 2]; buffer1[half_size + i] = b_arr[i * 2]; buffer2[i] = a_arr[i * 2 + 1]; buffer2[half_size + i] = b_arr[i * 2 + 1]; } return std::make_pair( Vectorized::loadu(static_cast(buffer1)), Vectorized::loadu(static_cast(buffer2))); } VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(deinterleave2) // clang-format off // inverse operation of deinterleave2 // Example inputs for AVX512: // a Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15} // b Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15} // returns, for AVX512: // Vectorized = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7} // Vectorized = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15} // Example inputs for AVX2 : a Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7} // b Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7} // returns: Vectorized = {a0, b0, a1, b1, a2, b2, a3, b3} // Vectorized = {a4, b4, a5, b5, a6, b6, a7, b7} // clang-format on template inline std::enable_if_t< Vectorized::size() % 2 == 0, std::pair, Vectorized>> interleave2(const Vectorized& a, const Vectorized& b) { static constexpr int size = Vectorized::size(); static constexpr int half_size = size / 2; T a_arr[size]; T b_arr[size]; T buffer1[size]; T buffer2[size]; a.store(static_cast(a_arr)); b.store(static_cast(b_arr)); for (const auto i : c10::irange(half_size)) { buffer1[i * 2] = a_arr[i]; buffer1[i * 2 + 1] = b_arr[i]; buffer2[i * 2] = a_arr[half_size + i]; buffer2[i * 2 + 1] = b_arr[half_size + i]; } return std::make_pair( Vectorized::loadu(static_cast(buffer1)), Vectorized::loadu(static_cast(buffer2))); } VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(interleave2) #undef VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC #undef VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP #undef VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC template inline void convert(const src_T* src, dst_T* dst, int64_t n) { #ifndef _MSC_VER #pragma unroll #endif for ([[maybe_unused]] const auto i : c10::irange(n)) { *dst = c10::convert(c10::load(src)); src++; dst++; } } template inline Vectorized flip(const Vectorized& data) { static constexpr int size = Vectorized::size(); T output[size]; T buffer[size]; data.store(static_cast(buffer)); for (const auto i : c10::irange(size)) { output[i] = buffer[size - i - 1]; } return Vectorized::loadu(static_cast(output)); } // Transpose the `src` buffer of type `T` and size (M,N) into the `dst` buffer. // `ld_src` is the leading dimension of `src` and `ld_dst` is the leading // dimension of `dst`. template inline void transpose_mxn( const T* src, int64_t ld_src, T* dst, int64_t ld_dst, int M, int N) { for (int i = 0; i < M; i++) { for (int j = 0; j < N; j++) { dst[j * ld_dst + i] = src[i * ld_src + j]; } } } template inline void transpose_mxn( const T* src, int64_t ld_src, T* dst, int64_t ld_dst) { transpose_mxn(src, ld_src, dst, ld_dst, M, N); } } // namespace CPU_CAPABILITY } // namespace at::vec // additional headers for more operations that depend on vec_base #include #include #include