#pragma once #include #include namespace at::native { template inline void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(Dtype)) { static_assert(false&&sizeof(Dtype),"at::cuda::blas_gemm_internal_ck: not implemented"); } #if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM) template <> void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(double)); template <> void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(float)); template <> void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(at::Half)); template <> void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)); #endif } // namespace at::native