#pragma once // On Windows, math.h needs to be included with _USE_MATH_DEFINES defined to // access constants such as M_SQRT2 and M_2_SQRTPI. #ifdef _WIN32 #define _USE_MATH_DEFINES #include #include #endif // _WIN32 #include #include // For c10::is_reduced_floating_point_v. namespace at::native { inline namespace CPU_CAPABILITY { constexpr double kGeluBeta = M_SQRT2 * M_2_SQRTPI * 0.5; constexpr double kGeluKappa = 0.044715; template using reduced_fp_to_float_t = std::conditional_t, float, T>; template , bool> = true> float reduced_fp_to_float(T x) { return float(x); } template , bool> = true> T reduced_fp_to_float(T x) { return x; } template T scalar_gelu_approximated_with_tanh(T x) { using opmath_t = reduced_fp_to_float_t; auto x_float = reduced_fp_to_float(x); auto x_cube = x_float * x_float * x_float; auto inner = opmath_t(kGeluBeta) * (x_float + opmath_t(kGeluKappa) * x_cube); return opmath_t(0.5) * x_float * (opmath_t(1) + std::tanh(inner)); } template , bool> = true> vec::Vectorized vectorized_gelu_approximated_with_tanh(vec::Vectorized x) { const vec::Vectorized kPointFiveVec(T(0.5)); const vec::Vectorized kOneVec(T(1)); const vec::Vectorized kGeluBetaVec((T(kGeluBeta))); const vec::Vectorized kGeluKappaVec((T(kGeluKappa))); auto x_cube = x * x * x; vec::Vectorized inner_vec = kGeluBetaVec * (x + kGeluKappaVec * x_cube); return kPointFiveVec * x * (kOneVec + inner_vec.tanh()); } template , bool> = true> vec::Vectorized vectorized_gelu_approximated_with_tanh(vec::Vectorized x) { auto [x0, x1] = at::vec::convert_to_float(x); return at::vec::convert_from_float( vectorized_gelu_approximated_with_tanh(x0), vectorized_gelu_approximated_with_tanh(x1)); } template T scalar_gelu(T x) { using opmath_t = reduced_fp_to_float_t; const auto kAlpha = opmath_t(M_SQRT1_2); return reduced_fp_to_float(x) * opmath_t(0.5) * (opmath_t(1) + std::erf(reduced_fp_to_float(x) * kAlpha)); } template, bool> = true> vec::Vectorized vectorized_gelu(vec::Vectorized x) { const vec::Vectorized kAlphaVec(T(M_SQRT1_2)); const vec::Vectorized kOneVec(T(1)); const vec::Vectorized kPointFiveVec(T(0.5)); return x * kPointFiveVec * (kOneVec + (x * kAlphaVec).erf()); } template, bool> = true> vec::Vectorized vectorized_gelu(vec::Vectorized x) { auto [x0, x1] = at::vec::convert_to_float(x); return at::vec::convert_from_float(vectorized_gelu(x0), vectorized_gelu(x1)); } } // namespace CPU_CAPABILITY } // namespace at::native