#pragma once

// This file provides two functions to help write GPU elementwise kernels:
//
//   gpu_kernel(TensorIterator iter, <lambda>)
//   gpu_kernel_with_scalars(TensorIterator iter, <lambda>)
//
// The gpu_kernel_with_scalars generates specializations that support a
// single scalar CPU argument, such as from `cuda_tensor + 5`. The CPU scalar
// is lifted to a kernel parameter instead of copying to device memory.
// This should be used in conjunction with TensorIterator::allow_cpu_scalars_,
// which is the default for TensorIterator::binary_op. Otherwise, all inputs
// and the output must be on the GPU.
//
// For example, to write a reciprocal kernel for GPU float Tensors:
//
//   gpu_kernel(iter, []GPU_LAMBDA(float a) {
//     return 1.0f / a;
//   });
//
// To write a multiplication kernel for GPU float Tensors where one argument
// may be a CPU scalar:
//
//   gpu_kernel_with_scalars(iter, []GPU_LAMBDA(float a, float b) {
//     return a * b;
//   });
//
// See BinaryOpsKernel.cu for the complete implementation
//

#include <type_traits>
#include <tuple>

#include <ATen/core/Array.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/detail/FunctionTraits.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <ATen/native/cuda/thread_constants.h>
#include <c10/core/DynamicCast.h>
#include <c10/macros/Macros.h>

#ifdef __NVCC__
#define ASSERT_HOST_DEVICE_LAMBDA(type)                       \
  static_assert(                                              \
      __nv_is_extended_host_device_lambda_closure_type(type), \
      #type " must be a __host__ __device__ lambda")
#else
#define ASSERT_HOST_DEVICE_LAMBDA(type)
#endif

namespace at::native {

#ifdef USE_ROCM
// Custom configuration for vectorized elementwise kernel
// with template instantiation.
namespace vectorized_templated_config {
constexpr int num_threads() {
  return 512;
}

constexpr int elems_per_thread() {
  return 32;
}

constexpr int block_work_size() {
  return elems_per_thread() * num_threads();
}
} // namespace vectorized_templated_config
#endif

template <typename args_t, size_t... Is>
constexpr auto sum_of_sizes(args_t args, std::index_sequence<Is...>) {
  if constexpr (sizeof...(Is) == 0) {
    return 0;
  } else {
    return (sizeof(std::tuple_element_t<Is, args_t>) + ...);
  }
}

#ifdef USE_ROCM
template <int io_sizes>
constexpr auto elems_per_thread() {
  if constexpr (io_sizes == 1) {
    return 16;
  } else if constexpr (io_sizes < 4) {
    return 8;
  } else {
    return 4;
  }
}
#else
template <int io_sizes>
constexpr auto elems_per_thread() {
  if constexpr (io_sizes == 1) {
    return 16;
  } else {
    return 8;
  }
}
#endif

// A thread work size of 8 regresses the perf of the elementwise kernel on CUDA.
// This doesn't change ROCm behavior, as thread_work_size is already 4 on ROCm.
constexpr int elementwise_thread_work_size() {
  return 4;
}
constexpr int elementwise_block_work_size() {
  return elementwise_thread_work_size() * num_threads();
}

template <int io_sizes>
constexpr auto io_block_work_size() {
  return num_threads() * elems_per_thread<io_sizes>();
}

#ifdef USE_ROCM
template <typename args_t, size_t... Is>
constexpr auto input_size(args_t args, std::index_sequence<Is...>) {
  if constexpr (sizeof...(Is) == 0) {
    return 0;
  } else {
    return sizeof(std::tuple_element_t<0, args_t>);
  }
}

template <int vec_size, int io_size>
constexpr auto calc_optimal_vec_size() {
  static_assert(vec_size != 0);
  static_assert(io_size != 0);
  if constexpr (io_size == 1 && vec_size >= 16) {
    return 16;
  } else if constexpr (io_size <= 2 && vec_size >= 8) {
    return 8;
  } else if constexpr (io_size <= 4 && vec_size >= 4) {
    return 4;
  } else if constexpr (vec_size >= 4) {
    return 4;
  } else if constexpr (vec_size >= 2) {
    return 2;
  } else {
    return 1;
  }
}
#endif

template <typename func_t>
constexpr auto calc_io_size() {
  using traits = function_traits<func_t>;
  using args_t = typename traits::ArgsTuple;
#ifdef USE_ROCM
  constexpr auto input_size = at::native::input_size(
      args_t{}, std::make_index_sequence<std::tuple_size_v<args_t>>{});
  constexpr auto output_size = sizeof(typename traits::result_type);
  return (input_size > 0)
      ? ((input_size < output_size) ? input_size : output_size)
      : output_size;
#else
  constexpr auto input_size = at::native::sum_of_sizes(
      args_t{}, std::make_index_sequence<std::tuple_size_v<args_t>>{});
  constexpr auto output_size = sizeof(typename traits::result_type);
  return input_size + output_size;
#endif
}
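// Worked example (illustrative only): for a binary functor with signature
// float(float, float) compiled without USE_ROCM, calc_io_size() returns
// sizeof(float) * 3 == 12, so elems_per_thread<12>() == 8 and
// io_block_work_size<12>() == num_threads() * 8. Assuming the usual
// num_threads() of 128 from thread_constants.h, each block of the vectorized
// kernels below then covers 1024 elements.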
#ifndef USE_ROCM
// To save on binary size of libtorch_cuda.so, we split the
// vectorized_elementwise_kernel into two: one for vec_size=8 and one for
// vec_size=[2, 4], since vec8 is going to be used on sm_90 and sm_100
// exclusively.
template <int vec_size, typename func_t, typename array_t>
C10_LAUNCH_BOUNDS_1(num_threads())
__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
  if constexpr (vec_size == 8) {
#if __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000
    using traits = function_traits<func_t>;
    constexpr auto io_size = calc_io_size<func_t>();
    int remaining = N - io_block_work_size<io_size>() * blockIdx.x;
    if (remaining < io_block_work_size<io_size>()) {
      // if this block handles the remainder, just do a naive unrolled loop
      auto input_calc = TrivialOffsetCalculator<traits::arity>();
      auto output_calc = TrivialOffsetCalculator<1>();
      auto loader = memory::LoadWithoutCast();
      auto storer = memory::StoreWithoutCast();
      auto policy = memory::policies::unroll<
          array_t,
          decltype(input_calc),
          decltype(output_calc),
          memory::LoadWithoutCast,
          memory::StoreWithoutCast,
          elems_per_thread<io_size>()>(
          data, remaining, input_calc, output_calc, loader, storer);
      elementwise_kernel_helper(f, policy);
    } else {
      // if this block has a full `block_work_size` of data to handle, use
      // vectorized memory access
      elementwise_kernel_helper(
          f,
          memory::policies::vectorized<
              vec_size,
              array_t,
              elems_per_thread<io_size>()>(data));
    }
#endif // __CUDA_ARCH__ == 900 || __CUDA_ARCH__ == 1000
  } else {
    using traits = function_traits<func_t>;
    constexpr auto io_size = calc_io_size<func_t>();
    int remaining = N - io_block_work_size<io_size>() * blockIdx.x;
    if (remaining < io_block_work_size<io_size>()) {
      // if this block handles the remainder, just do a naive unrolled loop
      auto input_calc = TrivialOffsetCalculator<traits::arity>();
      auto output_calc = TrivialOffsetCalculator<1>();
      auto loader = memory::LoadWithoutCast();
      auto storer = memory::StoreWithoutCast();
      auto policy = memory::policies::unroll<
          array_t,
          decltype(input_calc),
          decltype(output_calc),
          memory::LoadWithoutCast,
          memory::StoreWithoutCast,
          elems_per_thread<io_size>()>(
          data, remaining, input_calc, output_calc, loader, storer);
      elementwise_kernel_helper(f, policy);
    } else {
      // if this block has a full `block_work_size` of data to handle, use
      // vectorized memory access
      elementwise_kernel_helper(
          f,
          memory::policies::vectorized<
              vec_size,
              array_t,
              elems_per_thread<io_size>()>(data));
    }
  }
}
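// Example of how a block picks between the two branches above (illustrative,
// assuming num_threads() == 128): for a unary float functor io_size == 8, so
// io_block_work_size<8>() == 128 * 8 == 1024. With N == 10000 the launcher
// uses 10 blocks; blocks 0-8 take the vectorized path, while block 9 sees
// remaining == 10000 - 9 * 1024 == 784 < 1024 and falls back to the unrolled
// remainder loop.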
#else // USE_ROCM
template <int vec_size, typename func_t, typename array_t>
C10_LAUNCH_BOUNDS_1(num_threads())
__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
  using traits = function_traits<func_t>;
  constexpr auto io_size = calc_io_size<func_t>();
#if defined(USE_ROCM) && defined(__gfx942__)
  // Similar check in launch_vectorized_kernel() as well. Both should be in
  // sync.
  constexpr int tws = 16;
#else
  constexpr int tws = elems_per_thread<io_size>();
#endif
  constexpr int bws = tws * num_threads();
  int remaining = N - bws * blockIdx.x;
  if (remaining < bws) {
    // if this block handles the remainder, just do a naive unrolled loop
    auto input_calc = TrivialOffsetCalculator<traits::arity>();
    auto output_calc = TrivialOffsetCalculator<1>();
    auto loader = memory::LoadWithoutCast();
    auto storer = memory::StoreWithoutCast();
    auto policy = memory::policies::unroll<
        array_t,
        decltype(input_calc),
        decltype(output_calc),
        memory::LoadWithoutCast,
        memory::StoreWithoutCast,
        tws>(data, remaining, input_calc, output_calc, loader, storer);
    elementwise_kernel_helper(f, policy);
  } else {
    // if this block has a full `block_work_size` of data to handle, use
    // vectorized memory access
    constexpr auto optimal_vec_size =
        calc_optimal_vec_size<vec_size, io_size>();
    elementwise_kernel_helper(
        f, memory::policies::vectorized<optimal_vec_size, array_t, tws>(data));
  }
}
#endif // USE_ROCM

template <
    typename func_t,
    typename array_t,
    int elems_per_thread,
    typename inp_calc_t,
    typename out_calc_t,
    typename loader_t,
    typename storer_t>
C10_LAUNCH_BOUNDS_1(num_threads())
__global__ void unrolled_elementwise_kernel(
    int N,
    func_t f,
    array_t data,
    inp_calc_t ic,
    out_calc_t oc,
    loader_t l,
    storer_t s) {
  int remaining = N - elems_per_thread * num_threads() * blockIdx.x;
  auto policy = memory::policies::unroll<
      array_t,
      inp_calc_t,
      out_calc_t,
      loader_t,
      storer_t,
      elems_per_thread>(data, remaining, ic, oc, l, s);
  elementwise_kernel_helper(f, policy);
}
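// Sketch of what the unroll policy does per thread (illustrative): with
// num_threads() == 128 and elems_per_thread == 4, thread t of block b
// processes the elements at indices b * 512 + t, b * 512 + t + 128,
// b * 512 + t + 256 and b * 512 + t + 384, skipping any index that falls
// beyond `remaining`.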
// this function assumes a trivial 1d layout and no dynamic casting
template <typename func_t, typename array_t>
static inline void launch_vectorized_kernel(
    int64_t N,
    const func_t& f,
    array_t data) {
  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  using traits = function_traits<func_t>;
  constexpr auto io_size = calc_io_size<func_t>();
  auto stream = at::cuda::getCurrentCUDAStream();
#ifdef USE_ROCM
  int vec_size = memory::can_vectorize_up_to<func_t>(data);
  c10::DeviceIndex curDevice = -1;
  AT_CUDA_CHECK(c10::cuda::GetDevice(&curDevice));
  // Similar check in vectorized_elementwise_kernel() as well. Both should be
  // in sync.
  int tws = at::detail::getCUDAHooks().isGPUArch({"gfx942"}, curDevice)
      ? 16
      : elems_per_thread<io_size>();
#else
  using cpp_type = typename function_traits<func_t>::result_type;
  const uint16_t max_vec_size = memory::can_vectorize_up_to<func_t>(data);
  uint16_t vec_size = 16 / static_cast<uint16_t>(sizeof(cpp_type));
  vec_size = std::min(vec_size, max_vec_size);
  // Here we purposely omit vec8 for 1-byte data because of a bug in NVCC
  // that causes some numerical mismatches with uint8 on sm80 and sm90.
  // TODO: Revisit this after CUDA 12.8 update.
  cudaDeviceProp* p = at::cuda::getDeviceProperties(stream.device().index());
  const int computeCapability = p->major * 10 + p->minor;
  if (computeCapability != 90 && computeCapability != 100) {
    vec_size = std::min<uint16_t>(vec_size, 4);
  }
  if constexpr (sizeof(cpp_type) < 2) {
    vec_size = std::min<uint16_t>(vec_size, 4);
  }
  int tws = elems_per_thread<io_size>();
#endif
  int bws = tws * num_threads();
  int64_t grid = (N + bws - 1) / bws;

  switch (vec_size) {
#ifdef USE_ROCM
    case 16:
      vectorized_elementwise_kernel<16, func_t, array_t>
          <<<grid, num_threads(), 0, stream>>>(N, f, data);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
      break;
#endif
    case 8:
      vectorized_elementwise_kernel<8, func_t, array_t>
          <<<grid, num_threads(), 0, stream>>>(N, f, data);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
      break;
    case 4:
      vectorized_elementwise_kernel<4, func_t, array_t>
          <<<grid, num_threads(), 0, stream>>>(N, f, data);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
      break;
    case 2:
      vectorized_elementwise_kernel<2, func_t, array_t>
          <<<grid, num_threads(), 0, stream>>>(N, f, data);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
      break;
    case 1: {
      auto input_calc = TrivialOffsetCalculator<traits::arity>();
      auto output_calc = TrivialOffsetCalculator<1>();
      auto loader = memory::LoadWithoutCast();
      auto storer = memory::StoreWithoutCast();
      int64_t grid_unrolled = (N + elementwise_block_work_size() - 1) /
          elementwise_block_work_size();
      unrolled_elementwise_kernel<
          func_t,
          array_t,
          elementwise_thread_work_size()>
          <<<grid_unrolled, num_threads(), 0, stream>>>(
              N, f, data, input_calc, output_calc, loader, storer);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
      break;
    }
    default:
      TORCH_INTERNAL_ASSERT(false, "Unexpected vectorization size");
  }
}
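// Illustrative dispatch (assuming 16-byte aligned pointers and the CUDA
// branch above): a Half functor gives vec_size = 16 / 2 = 8, which only
// survives on compute capability 9.0/10.0 and is clamped to 4 elsewhere; a
// float functor gives 16 / 4 = 4; 1-byte types are clamped to 4 per the NVCC
// note above. can_vectorize_up_to() may lower these further when the data
// pointers are less aligned.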
#ifdef USE_ROCM
template <
    int vec_size,
    typename func_t,
    typename array_t,
    typename inp_calc_t,
    typename out_calc_t,
    typename loader_t,
    typename storer_t,
    typename OutputType,
    typename... InputTypes>
C10_LAUNCH_BOUNDS_1(vectorized_templated_config::num_threads())
__global__ void vectorized_templated_elementwise_kernel(
    int N,
    func_t f,
    array_t data,
    inp_calc_t inp_calc,
    out_calc_t out_calc,
    loader_t loader,
    storer_t storer) {
  int remaining = N -
      vectorized_templated_config::block_work_size() *
          (gridDim.x - blockIdx.x - 1);
  constexpr bool reverted_idx = true;
  if (remaining < vectorized_templated_config::block_work_size()) {
    // if this block handles the remainder, just do a naive unrolled loop
    auto policy = memory::policies::unroll_base<
        vectorized_templated_config::num_threads(),
        array_t,
        inp_calc_t,
        out_calc_t,
        loader_t,
        storer_t,
        vectorized_templated_config::elems_per_thread()>(
        data, remaining, inp_calc, out_calc, loader, storer);
    elementwise_kernel_helper(f, policy);
  } else {
    // if this block has a full `block_work_size` of data to handle, use
    // vectorized memory access
    auto policy = memory::policies::vectorized_templated<
        vec_size,
        array_t,
        vectorized_templated_config::elems_per_thread(),
        vectorized_templated_config::num_threads(),
        OutputType,
        InputTypes...>(data);
    elementwise_kernel_helper(f, policy);
  }
}

// This function assumes a trivial 1d layout and supports template
// specialization to avoid dynamic casting.
// Input vectorization size is based on runtime information, i.e.
// the actual data types of the input and output tensor, and cannot
// be determined using the functor type, as in regular non-templated
// vectorized kernels. The caller is in charge of selecting the correct input
// vectorization length.
template <
    typename func_t,
    typename array_t,
    typename inp_calc_t,
    typename out_calc_t,
    typename loader_t,
    typename storer_t,
    typename OutputType,
    typename... InputTypes>
static inline void launch_vectorized_templated_kernel(
    int64_t N,
    const func_t& f,
    array_t data,
    inp_calc_t ic,
    out_calc_t oc,
    loader_t l,
    storer_t s) {
  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  int64_t grid = (N + vectorized_templated_config::block_work_size() - 1) /
      vectorized_templated_config::block_work_size();
  auto stream = at::cuda::getCurrentCUDAStream();
  int vec_size = memory::can_vectorize_up_to<func_t>(data);
  switch (vec_size) {
    case 8:
      vectorized_templated_elementwise_kernel<
          8,
          func_t,
          array_t,
          inp_calc_t,
          out_calc_t,
          loader_t,
          storer_t,
          OutputType,
          InputTypes...>
          <<<grid, vectorized_templated_config::num_threads(), 0, stream>>>(
              N, f, data, ic, oc, l, s);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
      break;
    case 4:
      vectorized_templated_elementwise_kernel<
          4,
          func_t,
          array_t,
          inp_calc_t,
          out_calc_t,
          loader_t,
          storer_t,
          OutputType,
          InputTypes...>
          <<<grid, vectorized_templated_config::num_threads(), 0, stream>>>(
              N, f, data, ic, oc, l, s);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
      break;
    case 2:
      vectorized_templated_elementwise_kernel<
          2,
          func_t,
          array_t,
          inp_calc_t,
          out_calc_t,
          loader_t,
          storer_t,
          OutputType,
          InputTypes...>
          <<<grid, vectorized_templated_config::num_threads(), 0, stream>>>(
              N, f, data, ic, oc, l, s);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
      break;
    default:
      // vector size 1 is not handled as part of vectorize_templated kernel
      TORCH_INTERNAL_ASSERT(false, "Unexpected vectorization size");
  }
}
#endif

template <
    typename func_t,
    typename array_t,
    typename inp_calc_t,
    typename out_calc_t,
    typename loader_t,
    typename storer_t>
static inline void launch_unrolled_kernel(
    int64_t N,
    const func_t& f,
    array_t data,
    inp_calc_t ic,
    out_calc_t oc,
    loader_t l,
    storer_t s) {
  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());

  int64_t grid =
      (N + elementwise_block_work_size() - 1) / elementwise_block_work_size();
  auto stream = at::cuda::getCurrentCUDAStream();
  unrolled_elementwise_kernel<func_t, array_t, elementwise_thread_work_size()>
      <<<grid, num_threads(), 0, stream>>>(N, f, data, ic, oc, l, s);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template <int nt, int vt, typename func_t>
C10_LAUNCH_BOUNDS_2(nt, 4)
__global__ void elementwise_kernel(int N, func_t f) {
  int tid = threadIdx.x;
  int nv = nt * vt;
  int idx = nv * blockIdx.x + tid;
#pragma unroll
  for (int i = 0; i < vt; i++) {
    if (idx < N) {
      f(idx);
      idx += nt;
    }
  }
}

template <int nt, int vt, typename func_t>
static void launch_legacy_kernel(int64_t N, const func_t& f) {
  TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
  if (N == 0) {
    return;
  }
  dim3 block(nt);
  dim3 grid((N + block.x * vt - 1) / (block.x * vt));
  auto stream = at::cuda::getCurrentCUDAStream();
  elementwise_kernel<nt, vt, func_t><<<grid, block, 0, stream>>>(N, f);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

#ifdef USE_ROCM
template <int nt, int vt, typename func_t>
C10_LAUNCH_BOUNDS_2(nt, 4)
__global__ void elementwise_kernel_manual_unroll(int N, func_t f) {
  int tid = threadIdx.x;
  constexpr int nv = nt * vt;
  int idx = nv * blockIdx.x + tid;
  if ((idx + nt * (vt - 1)) < N) {
    f(idx, true);
  } else {
#pragma unroll
    for (int i = 0; i < vt; i++) {
      if (idx < N) {
        f(idx, false);
        idx += nt;
      }
    }
  }
}

template <int nt, int vt, typename func_t>
static void launch_legacy_kernel_manual_unroll(int64_t N, const func_t& f) {
  TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
  if (N == 0) {
    return;
  }
  dim3 block(nt);
  dim3 grid((N + block.x * vt - 1) / (block.x * vt));
  auto stream = at::cuda::getCurrentCUDAStream();
  elementwise_kernel_manual_unroll<nt, vt, func_t>
      <<<grid, block, 0, stream>>>(N, f);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
#endif

template <typename traits, typename func_t, typename index_t, size_t... INDEX>
C10_HOST_DEVICE typename traits::result_type invoke_impl(
    const func_t& f,
    char* const C10_RESTRICT data[],
    const index_t strides[],
    int i,
    std::index_sequence<INDEX...>) {
  (void)strides;
  (void)i;
  return f(c10::load<typename traits::template arg<INDEX>::type>(
      data[INDEX] + i * strides[INDEX])...);
}

template <
    typename func_t,
    typename index_t,
    typename traits = function_traits<func_t>>
C10_HOST_DEVICE typename traits::result_type invoke(
    const func_t& f,
    char* const C10_RESTRICT data[],
    const index_t strides[],
    int i) {
  using Indices = std::make_index_sequence<traits::arity>;
  return invoke_impl<traits>(f, data, strides, i, Indices{});
}
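// Example of the strided invoke above (illustrative): for a functor
// float(float, float), data = {in0_ptr, in1_ptr} and byte strides {4, 4},
// invoke(f, data, strides, i) expands to
//   f(c10::load<float>(data[0] + i * 4), c10::load<float>(data[1] + i * 4)).
// In the offset-calculator paths below, the `strides` argument actually
// carries per-element byte offsets and `i` is passed as 1, so the same
// expansion applies the precomputed offsets directly.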
template <typename traits, typename func_t, typename index_t, size_t... I>
C10_HOST_DEVICE typename traits::result_type invoke_impl(
    const func_t& f,
    char* const C10_RESTRICT data[],
    const index_t strides[],
    const ScalarType dtypes[],
    int i,
    std::index_sequence<I...>) {
  (void)strides;
  (void)i;
  return f(c10::fetch_and_cast<typename traits::template arg<I>::type>(
      dtypes[I], data[I] + i * strides[I])...);
}

template <
    typename func_t,
    typename index_t,
    typename traits = function_traits<func_t>>
C10_HOST_DEVICE typename traits::result_type invoke(
    const func_t& f,
    char* const C10_RESTRICT data[],
    const index_t strides[],
    const ScalarType dtypes[],
    int i) {
  using Indices = std::make_index_sequence<traits::arity>;
  return invoke_impl<traits>(f, data, strides, dtypes, i, Indices{});
}
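// Dynamic-cast variant (illustrative): with a float(float, float) functor and
// runtime dtypes {ScalarType::Half, ScalarType::Half}, each argument is read
// via c10::fetch_and_cast<float>(ScalarType::Half, ptr), so the functor always
// runs in float while the tensors stay in their storage types.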
template <typename func_t>
void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {
  using traits = function_traits<func_t>;
  using arg0_t = typename traits::result_type;
  constexpr int ntensors = traits::arity + 1;

  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
  TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
  TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));

  std::array<char*, ntensors> data;
  for (int i = 0; i < ntensors; i++) {
    data[i] = (char*)iter.data_ptr(i);
  }

  int64_t numel = iter.numel();

  bool contiguous = iter.is_contiguous();

  if (contiguous) {
    return launch_vectorized_kernel(numel, f, data);
  }
  auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
#ifndef USE_ROCM
  constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 2 : 4;
  launch_legacy_kernel<128, unroll_factor>(numel, [=] GPU_LAMBDA(int idx) {
    auto offsets = offset_calc.get(idx);
    arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
    *out = invoke(f, &data[1], &offsets[1], 1);
  });
#else
  constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 4 : 8;
  constexpr int grp_sz = 128;
  launch_legacy_kernel_manual_unroll<grp_sz, unroll_factor>(
      numel, [=] GPU_LAMBDA(int idx, bool unrl) {
        if (unrl) {
          if constexpr (unroll_factor == 4) {
            auto offsets0 = offset_calc.get(idx);
            auto offsets1 = offset_calc.get(idx + grp_sz);
            auto offsets2 = offset_calc.get(idx + grp_sz * 2);
            auto offsets3 = offset_calc.get(idx + grp_sz * 3);
            arg0_t* out0 = (arg0_t*)(data[0] + offsets0[0]);
            arg0_t* out1 = (arg0_t*)(data[0] + offsets1[0]);
            arg0_t* out2 = (arg0_t*)(data[0] + offsets2[0]);
            arg0_t* out3 = (arg0_t*)(data[0] + offsets3[0]);
            auto tmp0 = invoke(f, &data[1], &offsets0[1], 1);
            auto tmp1 = invoke(f, &data[1], &offsets1[1], 1);
            auto tmp2 = invoke(f, &data[1], &offsets2[1], 1);
            auto tmp3 = invoke(f, &data[1], &offsets3[1], 1);
            *out0 = tmp0;
            *out1 = tmp1;
            *out2 = tmp2;
            *out3 = tmp3;
          } else {
            auto offsets0 = offset_calc.get(idx);
            auto offsets1 = offset_calc.get(idx + grp_sz);
            auto offsets2 = offset_calc.get(idx + grp_sz * 2);
            auto offsets3 = offset_calc.get(idx + grp_sz * 3);
            auto offsets4 = offset_calc.get(idx + grp_sz * 4);
            auto offsets5 = offset_calc.get(idx + grp_sz * 5);
            auto offsets6 = offset_calc.get(idx + grp_sz * 6);
            auto offsets7 = offset_calc.get(idx + grp_sz * 7);
            arg0_t* out0 = (arg0_t*)(data[0] + offsets0[0]);
            arg0_t* out1 = (arg0_t*)(data[0] + offsets1[0]);
            arg0_t* out2 = (arg0_t*)(data[0] + offsets2[0]);
            arg0_t* out3 = (arg0_t*)(data[0] + offsets3[0]);
            arg0_t* out4 = (arg0_t*)(data[0] + offsets4[0]);
            arg0_t* out5 = (arg0_t*)(data[0] + offsets5[0]);
            arg0_t* out6 = (arg0_t*)(data[0] + offsets6[0]);
            arg0_t* out7 = (arg0_t*)(data[0] + offsets7[0]);
            auto tmp0 = invoke(f, &data[1], &offsets0[1], 1);
            auto tmp1 = invoke(f, &data[1], &offsets1[1], 1);
            auto tmp2 = invoke(f, &data[1], &offsets2[1], 1);
            auto tmp3 = invoke(f, &data[1], &offsets3[1], 1);
            auto tmp4 = invoke(f, &data[1], &offsets4[1], 1);
            auto tmp5 = invoke(f, &data[1], &offsets5[1], 1);
            auto tmp6 = invoke(f, &data[1], &offsets6[1], 1);
            auto tmp7 = invoke(f, &data[1], &offsets7[1], 1);
            *out0 = tmp0;
            *out1 = tmp1;
            *out2 = tmp2;
            *out3 = tmp3;
            *out4 = tmp4;
            *out5 = tmp5;
            *out6 = tmp6;
            *out7 = tmp7;
          }
        } else {
          auto offsets = offset_calc.get(idx);
          arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
          *out = invoke(f, &data[1], &offsets[1], 1);
        }
      });
#endif
}

#ifdef USE_ROCM
namespace {
template <
    typename TupleLike,
    typename FirstParamTy,
    typename SecondParamTy,
    size_t arity,
    size_t arg_num = 0>
struct check_binary_functor_types_for_specialization {
  constexpr static inline bool check() {
    if constexpr (arity != 2)
      return false;
    if constexpr (arg_num == 0) {
      using SelectedType = std::tuple_element_t<arg_num, TupleLike>;
      if constexpr (std::is_same_v<FirstParamTy, SelectedType>)
        return check_binary_functor_types_for_specialization<
            TupleLike,
            FirstParamTy,
            SecondParamTy,
            arity,
            arg_num + 1>::check();
    } else if constexpr (arg_num == 1) {
      using SelectedType2 = std::tuple_element_t<arg_num, TupleLike>;
      if constexpr (std::is_same_v<SecondParamTy, SelectedType2>)
        return check_binary_functor_types_for_specialization<
            TupleLike,
            FirstParamTy,
            SecondParamTy,
            arity,
            arg_num + 1>::check();
    }
    return false;
  }
};
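// Usage sketch (illustrative): with func_tuple = std::tuple<float, float>,
// check_binary_functor_types_for_specialization<func_tuple, float, float,
// 2>::check() evaluates to true, while std::tuple<float, int> (or any arity
// other than 2) makes it evaluate to false.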
// Bottom case: if we got this far, assume correct type matching except
// when there are no arguments (arity == 0).
template <
    typename TupleLike,
    typename FirstParamTy,
    typename SecondParamTy,
    size_t arity>
struct check_binary_functor_types_for_specialization<
    TupleLike,
    FirstParamTy,
    SecondParamTy,
    arity,
    arity> {
  constexpr static inline bool check() {
    if constexpr (arity != 0)
      return true;
    return false;
  }
};

template <typename TupleLike, typename FirstParamTy, typename SecondParamTy>
struct check_binary_functor_types_for_specialization<
    TupleLike,
    FirstParamTy,
    SecondParamTy,
    0,
    0> {
  constexpr static inline bool check() {
    return false;
  }
};

// The following is a list of type specializations for the vectorized_templated
// elementwise kernel. The three types refer to the runtime types of the output
// tensor, the first tensor argument, and the second tensor argument used for a
// binary functor.
constexpr std::array rt_binary_specializations = {
    std::array<c10::ScalarType, 3>(
        {c10::CppTypeToScalarType::value,
         c10::CppTypeToScalarType::value,
         c10::CppTypeToScalarType::value}),
    std::array<c10::ScalarType, 3>(
        {c10::CppTypeToScalarType::value,
         c10::CppTypeToScalarType::value,
         c10::CppTypeToScalarType::value}),
    std::array<c10::ScalarType, 3>(
        {c10::CppTypeToScalarType::value,
         c10::CppTypeToScalarType::value,
         c10::CppTypeToScalarType::value}),
    std::array<c10::ScalarType, 3>(
        {c10::CppTypeToScalarType::value,
         c10::CppTypeToScalarType::value,
         c10::CppTypeToScalarType::value}),
    std::array<c10::ScalarType, 3>(
        {c10::CppTypeToScalarType::value,
         c10::CppTypeToScalarType::value,
         c10::CppTypeToScalarType::value}),
    std::array<c10::ScalarType, 3>(
        {c10::CppTypeToScalarType::value,
         c10::CppTypeToScalarType::value,
         c10::CppTypeToScalarType::value})};

bool check_binary_rt_types_for_specialization(TensorIteratorBase& iter) {
  if (iter.ninputs() != 2)
    return false;
  for (auto spec : rt_binary_specializations)
    if (iter.dtype(0) == spec[0] && iter.input_dtype(0) == spec[1] &&
        iter.input_dtype(1) == spec[2])
      return true;
  return false;
}

template <int arg_index>
struct type_specialized_kernel_launcher {
  template <
      typename func_t,
      typename array_t,
      typename inp_calc_t,
      typename out_calc_t,
      typename loader_t,
      typename storer_t>
  static void apply(
      ScalarType ret_t,
      ScalarType arg0_t,
      ScalarType arg1_t,
      int64_t numel,
      func_t f,
      array_t data,
      inp_calc_t input_offset_calculator,
      out_calc_t output_offset_calculator,
      loader_t loader,
      storer_t storer) {
    if (ret_t == rt_binary_specializations[arg_index][0] &&
        arg0_t == rt_binary_specializations[arg_index][1] &&
        arg1_t == rt_binary_specializations[arg_index][2])
      launch_vectorized_templated_kernel<
          func_t,
          array_t,
          inp_calc_t,
          out_calc_t,
          loader_t,
          storer_t,
          decltype(c10::impl::ScalarTypeToCPPType<
                   rt_binary_specializations[arg_index][0]>::t),
          decltype(c10::impl::ScalarTypeToCPPType<
                   rt_binary_specializations[arg_index][1]>::t),
          decltype(c10::impl::ScalarTypeToCPPType<
                   rt_binary_specializations[arg_index][2]>::t)>(
          numel,
          f,
          data,
          input_offset_calculator,
          output_offset_calculator,
          loader,
          storer);
  }
};

} // namespace
#endif
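// Dispatch sketch (illustrative): static_unroll<type_specialized_kernel_launcher,
// rt_binary_specializations.size()>::with_args(ret, in0, in1, ...) expands into
// a compile-time loop over the table above; each
// type_specialized_kernel_launcher<i>::apply compares the runtime dtypes
// against entry i and, on a match, launches the templated kernel with those
// types baked in, so no per-element dynamic casting is needed on the fast path.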
template <typename func_t>
void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
  if (!needs_dynamic_casting<func_t>::check(iter)) {
    return gpu_kernel_impl_nocast(iter, f);
  }

  using traits = function_traits<func_t>;
  using arg0_t = typename traits::result_type;
  constexpr int ntensors = traits::arity + 1;

  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
  TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);

  std::array<char*, ntensors> data;
  for (int i = 0; i < ntensors; i++) {
    data[i] = (char*)iter.data_ptr(i);
  }

  int64_t numel = iter.numel();

  bool contiguous = iter.is_contiguous();

  if (contiguous) {
#ifdef USE_ROCM
    // Attempt to call the specialized vectorized elementwise kernel
    // that enables interleaving.
    if (check_binary_rt_types_for_specialization(iter) &&
        memory::can_vectorize_up_to<func_t>(data) > 1) {
      // constexpr to reduce the amount of kernels generated for
      // vectorized templated elementwise and limit which functors are actually
      // applied to the load and store at compile time.
      using func_tuple = typename traits::ArgsTuple;
      if constexpr (
          std::is_same_v<float, arg0_t> && traits::arity == 2 &&
          check_binary_functor_types_for_specialization<
              func_tuple,
              float,
              float,
              traits::arity,
              /*arg_num=*/0>::check()) {
        // If we got here, we know we are in one of the specialized cases. We
        // need to translate the runtime type to a statically known type. This
        // is effectively hoisting to the host the switch over runtime type in
        // the kernel in fetch_and_cast. Loader, storer, offset calculators are
        // only needed for the remainder loop.
        auto input_offset_calculator = TrivialOffsetCalculator<traits::arity>();
        auto output_offset_calculator = TrivialOffsetCalculator<1>();
        auto loader = memory::LoadWithCast<traits::arity>(iter);
        auto storer = memory::StoreWithCast<1>(iter);
        memory::detail::static_unroll<
            type_specialized_kernel_launcher,
            rt_binary_specializations.size()>::
            with_args(
                iter.dtype(0),
                iter.input_dtype(0),
                iter.input_dtype(1),
                numel,
                f,
                data,
                input_offset_calculator,
                output_offset_calculator,
                loader,
                storer);
        return;
      }
    }
    std::array<ScalarType, ntensors> dtypes;
    auto inner_strides = iter.get_inner_strides();
    std::array<int, ntensors> strides;
    for (int i = 0; i < ntensors; i++) {
      dtypes[i] = iter.dtype(i);
      strides[i] = inner_strides[i];
    }
    constexpr int grp_sz = 128;
    launch_legacy_kernel_manual_unroll<grp_sz, 4>(
        numel, [=] GPU_LAMBDA(int idx, bool unrl) {
          if (unrl) {
            void* out0 = data[0] + strides[0] * idx;
            void* out1 = data[0] + strides[0] * (idx + grp_sz);
            void* out2 = data[0] + strides[0] * (idx + grp_sz * 2);
            void* out3 = data[0] + strides[0] * (idx + grp_sz * 3);
            arg0_t result0 = invoke(f, &data[1], &strides[1], &dtypes[1], idx);
            arg0_t result1 =
                invoke(f, &data[1], &strides[1], &dtypes[1], idx + grp_sz);
            arg0_t result2 =
                invoke(f, &data[1], &strides[1], &dtypes[1], idx + grp_sz * 2);
            arg0_t result3 =
                invoke(f, &data[1], &strides[1], &dtypes[1], idx + grp_sz * 3);
            c10::cast_and_store<arg0_t>(dtypes[0], out0, result0);
            c10::cast_and_store<arg0_t>(dtypes[0], out1, result1);
            c10::cast_and_store<arg0_t>(dtypes[0], out2, result2);
            c10::cast_and_store<arg0_t>(dtypes[0], out3, result3);
          } else {
            void* out = data[0] + strides[0] * idx;
            arg0_t result = invoke(f, &data[1], &strides[1], &dtypes[1], idx);
            c10::cast_and_store<arg0_t>(dtypes[0], out, result);
          }
        });
#else
    auto loader = memory::LoadWithCast<traits::arity>(iter);
    auto storer = memory::StoreWithCast<1>(iter);
    auto input_offset_calculator = TrivialOffsetCalculator<traits::arity>();
    auto output_offset_calculator = TrivialOffsetCalculator<1>();
    launch_unrolled_kernel(
        numel,
        f,
        data,
        input_offset_calculator,
        output_offset_calculator,
        loader,
        storer);
#endif
  } else {
    std::array<ScalarType, ntensors> dtypes;
    for (int i = 0; i < ntensors; i++) {
      dtypes[i] = iter.dtype(i);
    }
    auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
    launch_legacy_kernel<128, 4>(numel, [=] GPU_LAMBDA(int idx) {
      auto offsets = offset_calc.get(idx);
      void* out = data[0] + offsets[0];
      arg0_t result = invoke(f, &data[1], &offsets[1], &dtypes[1], 1);
      c10::cast_and_store<arg0_t>(dtypes[0], out, result);
    });
  }
}

} // namespace at::native