#pragma once #include #include namespace at::cuda::detail { template struct enable_2x_kernel_for_sm89 : Kernel { template CUTLASS_DEVICE static void invoke(Args&&... args) { #if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 890 Kernel::invoke(std::forward(args)...); #endif } }; template struct enable_3x_kernel_for_sm9x : Kernel { template CUTLASS_DEVICE void operator()(Args&&... args) { #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 && __CUDA_ARCH__ < 1000 Kernel::operator()(std::forward(args)...); #endif } }; template struct enable_3x_kernel_for_sm10 : Kernel { template CUTLASS_DEVICE void operator()(Args&&... args) { #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 1000 && __CUDA_ARCH__ < 1200 Kernel::operator()(std::forward(args)...); #endif } }; template struct enable_3x_kernel_for_sm10_or_later : Kernel { template CUTLASS_DEVICE void operator()(Args&&... args) { #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 1000 Kernel::operator()(std::forward(args)...); #endif } }; } // namespace at::cuda::detail