// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #pragma once #include #include #include #include #include #include #include #include #include #define TORCH_HIPBLASLT_CHECK(EXPR) \ do { \ hipblasStatus_t __err = EXPR; \ TORCH_CHECK(__err == HIPBLAS_STATUS_SUCCESS, \ "hipblaslt error: ", \ hipblasStatusToString(__err), \ " when calling `" #EXPR "`"); \ } while (0) namespace at::cuda::tunable { template constexpr hipDataType HipDataTypeFor(); template <> constexpr hipDataType HipDataTypeFor() { return HIP_R_32F; } template <> constexpr hipDataType HipDataTypeFor() { return HIP_R_16F; } template <> constexpr hipDataType HipDataTypeFor() { return HIP_R_16BF; } template <> constexpr hipDataType HipDataTypeFor() { return HIP_R_64F; } template <> constexpr hipDataType HipDataTypeFor() { return HIP_R_8F_E4M3_FNUZ; } template <> constexpr hipDataType HipDataTypeFor() { return HIP_R_8F_E5M2_FNUZ; } // This code is instantiated regardless of ROCm version. // Prior to ROCm 6.3, we hard-code the known enum values. template <> constexpr hipDataType HipDataTypeFor() { #if ROCM_VERSION >= 60300 return HIP_R_8F_E4M3; #else return static_cast(28); #endif } template <> constexpr hipDataType HipDataTypeFor() { #if ROCM_VERSION >= 60300 return HIP_R_8F_E5M2; #else return static_cast(29); #endif } // This type is not intended for matrix types but rather a scale factor. // Return a dummy value to satisfy linker. template <> constexpr hipDataType HipDataTypeFor() { return static_cast(500); } template <> constexpr hipDataType HipDataTypeFor() { #if ROCM_VERSION >= 70000 return HIP_R_4F_E2M1; #else return static_cast(33); #endif } template int GetBatchFromParams(const GemmParams* params) { return 1; } template int GetBatchFromParams(const GemmAndBiasParams* params) { return 1; } template int GetBatchFromParams(const GemmStridedBatchedParams* params) { return params->batch; } template int GetBatchFromParams(const ScaledGemmParams* params) { return 1; } template int GetStrideAFromParams(const GemmParams* params) { return 1; } template int GetStrideAFromParams(const GemmAndBiasParams* params) { return 1; } template int GetStrideAFromParams(const GemmStridedBatchedParams* params) { return params->stride_a; } template int GetStrideAFromParams(const ScaledGemmParams* params) { return 1; } template int GetStrideBFromParams(const GemmParams* params) { return 1; } template int GetStrideBFromParams(const GemmAndBiasParams* params) { return 1; } template int GetStrideBFromParams(const GemmStridedBatchedParams* params) { return params->stride_b; } template int GetStrideBFromParams(const ScaledGemmParams* params) { return 1; } template int GetStrideCFromParams(const GemmParams* params) { return 1; } template int GetStrideCFromParams(const GemmAndBiasParams* params) { return 1; } template int GetStrideCFromParams(const GemmStridedBatchedParams* params) { return params->stride_c; } template int GetStrideCFromParams(const ScaledGemmParams* params) { return 1; } template float GetAlphaFromParams(const GemmParams* params) { return params->alpha; } template float GetAlphaFromParams(const GemmAndBiasParams* params) { return params->alpha; } template float GetAlphaFromParams(const GemmStridedBatchedParams* params) { return params->alpha; } template float GetAlphaFromParams(const ScaledGemmParams* params) { return 1.0; } template float GetBetaFromParams(const GemmParams* params) { return params->beta; } template float GetBetaFromParams(const GemmAndBiasParams* params) { return 0.0; } template float GetBetaFromParams(const GemmStridedBatchedParams* params) { return params->beta; } template float GetBetaFromParams(const ScaledGemmParams* params) { return 0.0; } template ScalingType GetAScalingTypeFromParams(const GemmParams* params) { return ScalingType::TensorWise; } template ScalingType GetBScalingTypeFromParams(const GemmParams* params) { return ScalingType::TensorWise; } template ScalingType GetAScalingTypeFromParams(const GemmAndBiasParams* params) { return ScalingType::TensorWise; } template ScalingType GetBScalingTypeFromParams(const GemmAndBiasParams* params) { return ScalingType::TensorWise; } template ScalingType GetAScalingTypeFromParams(const GemmStridedBatchedParams* params) { return ScalingType::TensorWise; } template ScalingType GetBScalingTypeFromParams(const GemmStridedBatchedParams* params) { return ScalingType::TensorWise; } template ScalingType GetAScalingTypeFromParams(const ScaledGemmParams* params) { return params->a_scaling_type; } template ScalingType GetBScalingTypeFromParams(const ScaledGemmParams* params) { return params->b_scaling_type; } template const void* GetAScalePointerFromParams(const GemmParams* params) { return nullptr; } template const void* GetAScalePointerFromParams(const GemmAndBiasParams* params) { return nullptr; } template const void* GetAScalePointerFromParams(const GemmStridedBatchedParams* params) { return nullptr; } template const void* GetAScalePointerFromParams(const ScaledGemmParams* params) { return params->a_scale_ptr; } template const void* GetBScalePointerFromParams(const GemmParams* params) { return nullptr; } template const void* GetBScalePointerFromParams(const GemmAndBiasParams* params) { return nullptr; } template const void* GetBScalePointerFromParams(const GemmStridedBatchedParams* params) { return nullptr; } template const void* GetBScalePointerFromParams(const ScaledGemmParams* params) { return params->b_scale_ptr; } template const void* GetDScalePointerFromParams(const GemmParams* params) { return nullptr; } template const void* GetDScalePointerFromParams(const GemmAndBiasParams* params) { return nullptr; } template const void* GetDScalePointerFromParams(const GemmStridedBatchedParams* params) { return nullptr; } template const void* GetDScalePointerFromParams(const ScaledGemmParams* params) { return params->c_scale_ptr; } template const void* GetBiasPointerFromParams(const GemmParams* params) { return nullptr; } template const void* GetBiasPointerFromParams(const GemmAndBiasParams* params) { return params->bias; } template const void* GetBiasPointerFromParams(const GemmStridedBatchedParams* params) { return nullptr; } template const void* GetBiasPointerFromParams(const ScaledGemmParams* params) { return params->bias_ptr; } template hipDataType GetBiasTypeFromParams(const GemmParams* params) { return HIP_R_32F; } template hipDataType GetBiasTypeFromParams(const GemmAndBiasParams* params) { return HipDataTypeFor(); } template hipDataType GetBiasTypeFromParams(const GemmStridedBatchedParams* params) { return HIP_R_32F; } template hipDataType GetBiasTypeFromParams(const ScaledGemmParams* params) { return at::cuda::ScalarTypeToCudaDataType(params->bias_dtype); } template at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmParams* params) { return at::cuda::blas::GEMMAndBiasActivationEpilogue::None; } template at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmAndBiasParams* params) { return params->activation; } template at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const GemmStridedBatchedParams* params) { return at::cuda::blas::GEMMAndBiasActivationEpilogue::None; } template at::cuda::blas::GEMMAndBiasActivationEpilogue GetActivationFromParams(const ScaledGemmParams* params) { return at::cuda::blas::GEMMAndBiasActivationEpilogue::None; } static hipblasOperation_t _hipblasOpFromChar(char op) { switch (op) { case 'n': case 'N': return HIPBLAS_OP_N; case 't': case 'T': return HIPBLAS_OP_T; case 'c': case 'C': return HIPBLAS_OP_C; } TORCH_CHECK(false, "_hipblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`"); } static char _charFromhipblasOp(hipblasOperation_t op) { switch (op) { case HIPBLAS_OP_N: return 'N'; case HIPBLAS_OP_T: return 'T'; case HIPBLAS_OP_C: return 'C'; } TORCH_CHECK(false, "_charFromhipblasOp input should be HIPBLAS_OP_N/T/C but got `", op, "`"); } static hipblasOperation_t MapLayoutToHipBlasLt(BlasOp layout) { if (layout == BlasOp::N) { return HIPBLAS_OP_N; } return HIPBLAS_OP_T; } template struct HipBlasLtDeleter { void operator()(T* x) { if (x != nullptr) { TORCH_CUDABLAS_CHECK(destructor(x)); } } }; template class HipBlasLtDescriptor { public: T* descriptor() const { return descriptor_.get(); } T* descriptor() { return descriptor_.get(); } protected: std::unique_ptr> descriptor_; }; class HipBlasLtMatmulDescriptor : public HipBlasLtDescriptor< hipblasLtMatmulDescOpaque_t, &hipblasLtMatmulDescDestroy> { public: HipBlasLtMatmulDescriptor( hipblasComputeType_t compute_type, hipDataType scale_type) { hipblasLtMatmulDesc_t raw_descriptor = nullptr; TORCH_HIPBLASLT_CHECK( hipblasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type)); descriptor_.reset(raw_descriptor); } template inline void setAttribute(hipblasLtMatmulDescAttributes_t attr, const T value) { TORCH_HIPBLASLT_CHECK(::hipblasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(T))); } }; template class HipblasltGemmOp : public Callable { public: HipblasltGemmOp(hipblasLtMatmulAlgo_t algo) : algo_{algo} {} TuningStatus Call(const ParamsT* params) override { hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout); hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout); auto a_datatype = HipDataTypeFor(); auto b_datatype = HipDataTypeFor(); auto in_out_datatype = HipDataTypeFor(); auto opa = _hipblasOpFromChar(params->transa); auto opb = _hipblasOpFromChar(params->transb); TORCH_CHECK(transa_outer == opa && transb_outer == opb, "trans mismatch, shouldn't happen"); float alpha = GetAlphaFromParams(params); float beta = GetBetaFromParams(params); hipblasLtMatrixLayout_t mat_a, mat_b, mat_c; if (opa == HIPBLAS_OP_N) { TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, a_datatype, params->m, params->k, params->lda)); } else { TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, a_datatype, params->k, params->m, params->lda)); } if (opb == HIPBLAS_OP_N) { TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, b_datatype, params->k, params->n, params->ldb)); } else { TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, b_datatype, params->n, params->k, params->ldb)); } TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_c, in_out_datatype, params->m, params->n, params->ldc)); // specific to batched gemmm int batch = GetBatchFromParams(params); if (batch > 1) { int64_t stride_a = GetStrideAFromParams(params); int64_t stride_b = GetStrideBFromParams(params); int64_t stride_c = GetStrideCFromParams(params); TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute( mat_a, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute( mat_a, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_a, sizeof(stride_a))); TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute( mat_b, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute( mat_b, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_b, sizeof(stride_b))); TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute( mat_c, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute( mat_c, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c, sizeof(stride_c))); } hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F; if (at::globalContext().float32Precision("cuda", "matmul") == "tf32") { computeType = HIPBLAS_COMPUTE_32F_FAST_TF32; } HipBlasLtMatmulDescriptor matmul(computeType, HIP_R_32F); matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSA, opa); matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSB, opb); // specific to scaled gemm const void* mat1_scale_ptr = GetAScalePointerFromParams(params); const void* mat2_scale_ptr = GetBScalePointerFromParams(params); const void* result_scale_ptr = GetDScalePointerFromParams(params); if (mat1_scale_ptr && mat2_scale_ptr) { hipblasLtMatmulDescAttributes_t a_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER; hipblasLtMatmulDescAttributes_t b_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER; if (GetAScalingTypeFromParams(params) == ScalingType::RowWise) { #if defined(HIPBLASLT_OUTER_VEC) matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); #elif defined(HIPBLASLT_VEC_EXT) a_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT; #endif } if (GetBScalingTypeFromParams(params) == ScalingType::RowWise) { #if defined(HIPBLASLT_OUTER_VEC) matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_MODE, HIPBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F); #elif defined(HIPBLASLT_VEC_EXT) b_scale_ptr_desc = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT; #endif } matmul.setAttribute(a_scale_ptr_desc, mat1_scale_ptr); matmul.setAttribute(b_scale_ptr_desc, mat2_scale_ptr); } if (result_scale_ptr) { matmul.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr); } const void* bias_ptr = GetBiasPointerFromParams(params); auto bias_datatype = GetBiasTypeFromParams(params); if (bias_ptr) { matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_POINTER, bias_ptr); matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, bias_datatype); auto activation = GetActivationFromParams(params); if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::RELU) { matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_RELU_BIAS); } else if (activation == at::cuda::blas::GEMMAndBiasActivationEpilogue::GELU) { matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_GELU_BIAS); } else { matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_BIAS); } } size_t workspace_size = at::cuda::getCUDABlasLtWorkspaceSize(); auto op_handle = at::cuda::getCurrentCUDABlasLtHandle(); size_t ret_workspace_size = 0; auto status = hipblaslt_ext::matmulIsAlgoSupported(op_handle, matmul.descriptor(), &alpha, mat_a, mat_b, &beta, mat_c, mat_c, algo_, ret_workspace_size); if (status == HIPBLAS_STATUS_SUCCESS) { if (ret_workspace_size >= workspace_size) { return FAIL; } } else { return FAIL; } void* workspace_buffer = at::cuda::getCUDABlasLtWorkspace(); TORCH_HIPBLASLT_CHECK(hipblasLtMatmul(op_handle, matmul.descriptor(), &alpha, params->a, mat_a, params->b, mat_b, &beta, params->c, mat_c, params->c, mat_c, &algo_, workspace_buffer, workspace_size, at::cuda::getCurrentCUDAStream())); //TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescDestroy(matmul)); TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_a)); TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_b)); TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_c)); return OK; } private: hipblasLtMatmulAlgo_t algo_; }; template auto GetHipBlasLtTypeStringAndOps() { hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout); hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout); auto a_datatype = HipDataTypeFor(); auto b_datatype = HipDataTypeFor(); auto in_out_datatype = HipDataTypeFor(); std::vector heuristic_result; #if ROCM_VERSION == 60400 // hipblaslt TT fp32 regression on ROCm 6.4, cannot use if ((a_datatype == HIP_R_32F || b_datatype == HIP_R_32F || in_out_datatype == HIP_R_32F) && (transa_outer == HIPBLAS_OP_T && transb_outer == HIPBLAS_OP_T)) { std::vector>>> ignore; return ignore; } #endif hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F; if (at::globalContext().allowTF32CuBLAS()) { computeType = HIPBLAS_COMPUTE_32F_FAST_TF32; } hipblasLtHandle_t handle; TORCH_HIPBLASLT_CHECK(hipblasLtCreate(&handle)); TORCH_HIPBLASLT_CHECK(hipblaslt_ext::getAllAlgos(handle, hipblaslt_ext::GemmType::HIPBLASLT_GEMM, transa_outer, transb_outer, a_datatype, b_datatype, in_out_datatype, in_out_datatype, computeType, heuristic_result)); TORCH_HIPBLASLT_CHECK(hipblasLtDestroy(handle)); int returned_algo_count = heuristic_result.size(); std::vector>>> ret; for (int i = 0; i < returned_algo_count; i++) { auto algo = heuristic_result[i].algo; int algo_index = hipblaslt_ext::getIndexFromAlgo(algo); auto callable = std::make_unique>(algo); std::string type_string = fmt::sprintf("Gemm_Hipblaslt_%d", algo_index); ret.emplace_back(type_string, std::move(callable)); } return ret; } template auto GetHipBlasLtGemmTypeStringAndOps() { return GetHipBlasLtTypeStringAndOps>(); } template auto GetHipBlasLtGemmAndBiasTypeStringAndOps() { return GetHipBlasLtTypeStringAndOps>(); } template auto GetHipBlasLtGemmStridedBatchedTypeStringAndOps() { return GetHipBlasLtTypeStringAndOps>(); } template auto GetHipBlasLtScaledGemmTypeStringAndOps() { return GetHipBlasLtTypeStringAndOps>(); } #undef TORCH_HIPBLASLT_CHECK } // namespace at::cuda::tunable