#pragma once // On Windows, math.h needs to be included with _USE_MATH_DEFINES defined to // access constants such as M_SQRT2 and M_2_SQRTPI. #ifdef _WIN32 #define _USE_MATH_DEFINES #include #endif // _WIN32 #include #include // For c10::is_reduced_floating_point_v. namespace at::native { inline namespace CPU_CAPABILITY { /** * Return a function object that calculates ELU with the given * parameters on its input element. ParamT is the type of the input * and output to the ELU, and MathT is the type (possibly * higher-precision, e.g. float if ParamT is reduced-precision float) * in which to do intermediate calculations. */ template auto get_scalar_elu_elementwise_func(MathT alpha, MathT scale, MathT input_scale) { const auto negcoef = alpha * scale; const auto poscoef = scale; const auto negiptcoef = input_scale; return [negcoef, negiptcoef, poscoef](ParamT a) -> ParamT { return MathT(a) < MathT(0) ? std::expm1(MathT(a) * negiptcoef) * negcoef : MathT(a) * poscoef; }; } /** * Return a function object that calculates ELU with the given * parameters on its input element. The function object takes and * returns Vectorized. */ template , bool> = true> auto get_vectorized_elu_elementwise_func(T alpha, T scale, T input_scale) { const vec::Vectorized negcoef_vec(alpha * scale); const vec::Vectorized poscoef_vec(scale); const vec::Vectorized negiptcoef_vec(input_scale); const vec::Vectorized zero_vec(static_cast(0)); return [negcoef_vec, poscoef_vec, negiptcoef_vec, zero_vec](vec::Vectorized a) -> vec::Vectorized { const auto cmp = a >= zero_vec; if (!cmp.zero_mask()) { return a * poscoef_vec; } else { return vec::Vectorized::blendv((a * negiptcoef_vec).expm1() * negcoef_vec, a * poscoef_vec, cmp); } }; } /** * Return a function object that calculates ELU with the given * parameters on its input element. The function object takes and * returns Vectorized, and Vectorized is the type * (possibly higher-precision) in which to do intermediate * calculations. */ template , bool> = true> auto get_vectorized_elu_elementwise_func(float alpha, float scale, float input_scale) { // Takes float->float. const auto float_func = get_vectorized_elu_elementwise_func(alpha, scale, input_scale); return [float_func](vec::Vectorized a) -> vec::Vectorized { auto [a0, a1] = vec::convert_to_float(a); auto res0 = float_func(a0); auto res1 = float_func(a1); return vec::convert_from_float(res0, res1); }; } } // namespace CPU_CAPABILITY } // namespace at::native