/* * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. * * NVIDIA CORPORATION and its licensors retain all intellectual property * and proprietary rights in and to this software, related documentation * and any modifications thereto. Any use, reproduction, disclosure or * distribution of this software and related documentation without an express * license agreement from NVIDIA CORPORATION is strictly prohibited. */ #if !defined(CUSPARSELT_HEADER_) #define CUSPARSELT_HEADER_ #include "cusparse.h" // cusparseStatus_t #include // size_t #include // cudaStream_t #include // cudaDataType #include // uint8_t //############################################################################## //# CUSPARSELT VERSION INFORMATION //############################################################################## #define CUSPARSELT_VER_MAJOR 0 #define CUSPARSELT_VER_MINOR 7 #define CUSPARSELT_VER_PATCH 1 #define CUSPARSELT_VER_BUILD 0 #define CUSPARSELT_VERSION (CUSPARSELT_VER_MAJOR * 1000 + \ CUSPARSELT_VER_MINOR * 100 + \ CUSPARSELT_VER_PATCH) // ############################################################################# // # MACRO // ############################################################################# #if !defined(CUSPARSELT_API) # if defined(_WIN32) # define CUSPARSELT_API __stdcall # else # define CUSPARSELT_API # endif #endif //------------------------------------------------------------------------------ #if defined(__cplusplus) extern "C" { #endif // defined(__cplusplus) //############################################################################## //# OPAQUE DATA STRUCTURES //############################################################################## typedef struct { uint8_t data[1024]; } cusparseLtHandle_t; typedef struct { uint8_t data[1024]; } cusparseLtMatDescriptor_t; typedef struct { uint8_t data[1024]; } cusparseLtMatmulDescriptor_t; typedef struct { uint8_t data[1024]; } cusparseLtMatmulAlgSelection_t; typedef struct { uint8_t data[1024]; } cusparseLtMatmulPlan_t; const char* CUSPARSELT_API cusparseLtGetErrorName(cusparseStatus_t status); const char* CUSPARSELT_API cusparseLtGetErrorString(cusparseStatus_t status); //############################################################################## //# INITIALIZATION, DESTROY //############################################################################## cusparseStatus_t CUSPARSELT_API cusparseLtInit(cusparseLtHandle_t* handle); cusparseStatus_t CUSPARSELT_API cusparseLtDestroy(const cusparseLtHandle_t* handle); cusparseStatus_t CUSPARSELT_API cusparseLtGetVersion(const cusparseLtHandle_t* handle, int* version); cusparseStatus_t CUSPARSELT_API cusparseLtGetProperty(libraryPropertyType propertyType, int* value); //############################################################################## //# MATRIX DESCRIPTOR //############################################################################## // Dense Matrix cusparseStatus_t CUSPARSELT_API cusparseLtDenseDescriptorInit(const cusparseLtHandle_t* handle, cusparseLtMatDescriptor_t* matDescr, int64_t rows, int64_t cols, int64_t ld, uint32_t alignment, cudaDataType valueType, cusparseOrder_t order); //------------------------------------------------------------------------------ // Structured Matrix typedef enum { CUSPARSELT_SPARSITY_50_PERCENT } cusparseLtSparsity_t; cusparseStatus_t CUSPARSELT_API cusparseLtStructuredDescriptorInit(const cusparseLtHandle_t* handle, cusparseLtMatDescriptor_t* matDescr, int64_t rows, int64_t cols, int64_t ld, uint32_t alignment, cudaDataType valueType, cusparseOrder_t order, cusparseLtSparsity_t sparsity); cusparseStatus_t CUSPARSELT_API cusparseLtMatDescriptorDestroy(const cusparseLtMatDescriptor_t* matDescr); //------------------------------------------------------------------------------ typedef enum { CUSPARSELT_MAT_NUM_BATCHES, // READ/WRITE CUSPARSELT_MAT_BATCH_STRIDE // READ/WRITE } cusparseLtMatDescAttribute_t; cusparseStatus_t CUSPARSELT_API cusparseLtMatDescSetAttribute(const cusparseLtHandle_t* handle, cusparseLtMatDescriptor_t* matmulDescr, cusparseLtMatDescAttribute_t matAttribute, const void* data, size_t dataSize); cusparseStatus_t CUSPARSELT_API cusparseLtMatDescGetAttribute(const cusparseLtHandle_t* handle, const cusparseLtMatDescriptor_t* matmulDescr, cusparseLtMatDescAttribute_t matAttribute, void* data, size_t dataSize); //############################################################################## //# MATMUL DESCRIPTOR //############################################################################## typedef enum { CUSPARSE_COMPUTE_32I, CUSPARSE_COMPUTE_16F, CUSPARSE_COMPUTE_32F } cusparseComputeType; cusparseStatus_t CUSPARSELT_API cusparseLtMatmulDescriptorInit(const cusparseLtHandle_t* handle, cusparseLtMatmulDescriptor_t* matmulDescr, cusparseOperation_t opA, cusparseOperation_t opB, const cusparseLtMatDescriptor_t* matA, const cusparseLtMatDescriptor_t* matB, const cusparseLtMatDescriptor_t* matC, const cusparseLtMatDescriptor_t* matD, cusparseComputeType computeType); //------------------------------------------------------------------------------ typedef enum { CUSPARSELT_MATMUL_ACTIVATION_RELU, // READ/WRITE CUSPARSELT_MATMUL_ACTIVATION_RELU_UPPERBOUND, // READ/WRITE CUSPARSELT_MATMUL_ACTIVATION_RELU_THRESHOLD, // READ/WRITE CUSPARSELT_MATMUL_ACTIVATION_GELU, // READ/WRITE CUSPARSELT_MATMUL_ACTIVATION_GELU_SCALING, // READ/WRITE CUSPARSELT_MATMUL_ALPHA_VECTOR_SCALING, // READ/WRITE CUSPARSELT_MATMUL_BETA_VECTOR_SCALING, // READ/WRITE CUSPARSELT_MATMUL_BIAS_STRIDE, // READ/WRITE CUSPARSELT_MATMUL_BIAS_POINTER, // READ/WRITE CUSPARSELT_MATMUL_SPARSE_MAT_POINTER, // READ/WRITE // CUSPARSELT_MATMUL_A_SCALE_MODE, // READ/WRITE CUSPARSELT_MATMUL_B_SCALE_MODE, // READ/WRITE CUSPARSELT_MATMUL_C_SCALE_MODE, // READ/WRITE CUSPARSELT_MATMUL_D_SCALE_MODE, // READ/WRITE CUSPARSELT_MATMUL_D_OUT_SCALE_MODE, // READ/WRITE CUSPARSELT_MATMUL_A_SCALE_POINTER, CUSPARSELT_MATMUL_B_SCALE_POINTER, CUSPARSELT_MATMUL_C_SCALE_POINTER, CUSPARSELT_MATMUL_D_SCALE_POINTER, CUSPARSELT_MATMUL_D_OUT_SCALE_POINTER, } cusparseLtMatmulDescAttribute_t; typedef enum { CUSPARSELT_MATMUL_SCALE_NONE, CUSPARSELT_MATMUL_MATRIX_SCALE_SCALAR_32F, CUSPARSELT_MATMUL_MATRIX_SCALE_VEC32_UE4M3, CUSPARSELT_MATMUL_MATRIX_SCALE_VEC64_UE8M0, } cusparseLtMatmulMatrixScale_t; cusparseStatus_t CUSPARSELT_API cusparseLtMatmulDescSetAttribute(const cusparseLtHandle_t* handle, cusparseLtMatmulDescriptor_t* matmulDescr, cusparseLtMatmulDescAttribute_t matmulAttribute, const void* data, size_t dataSize); cusparseStatus_t CUSPARSELT_API cusparseLtMatmulDescGetAttribute( const cusparseLtHandle_t* handle, const cusparseLtMatmulDescriptor_t* matmulDescr, cusparseLtMatmulDescAttribute_t matmulAttribute, void* data, size_t dataSize); //############################################################################## //# ALGORITHM SELECTION //############################################################################## typedef enum { CUSPARSELT_MATMUL_ALG_DEFAULT } cusparseLtMatmulAlg_t; cusparseStatus_t CUSPARSELT_API cusparseLtMatmulAlgSelectionInit( const cusparseLtHandle_t* handle, cusparseLtMatmulAlgSelection_t* algSelection, const cusparseLtMatmulDescriptor_t* matmulDescr, cusparseLtMatmulAlg_t alg); typedef enum { CUSPARSELT_MATMUL_ALG_CONFIG_ID, // READ/WRITE CUSPARSELT_MATMUL_ALG_CONFIG_MAX_ID, // READ-ONLY CUSPARSELT_MATMUL_SEARCH_ITERATIONS, // READ/WRITE CUSPARSELT_MATMUL_SPLIT_K, // READ/WRITE CUSPARSELT_MATMUL_SPLIT_K_MODE, // READ/WRITE CUSPARSELT_MATMUL_SPLIT_K_BUFFERS // READ/WRITE } cusparseLtMatmulAlgAttribute_t; typedef enum { CUSPARSELT_INVALID_MODE = 0, CUSPARSELT_SPLIT_K_MODE_ONE_KERNEL = 1, CUSPARSELT_SPLIT_K_MODE_TWO_KERNELS = 2, CUSPARSELT_HEURISTIC, CUSPARSELT_DATAPARALLEL, CUSPARSELT_SPLITK, CUSPARSELT_STREAMK } cusparseLtSplitKMode_t; cusparseStatus_t CUSPARSELT_API cusparseLtMatmulAlgSetAttribute(const cusparseLtHandle_t* handle, cusparseLtMatmulAlgSelection_t* algSelection, cusparseLtMatmulAlgAttribute_t attribute, const void* data, size_t dataSize); cusparseStatus_t CUSPARSELT_API cusparseLtMatmulAlgGetAttribute( const cusparseLtHandle_t* handle, const cusparseLtMatmulAlgSelection_t* algSelection, cusparseLtMatmulAlgAttribute_t attribute, void* data, size_t dataSize); //############################################################################## //# MATMUL PLAN //############################################################################## cusparseStatus_t CUSPARSELT_API cusparseLtMatmulGetWorkspace( const cusparseLtHandle_t* handle, const cusparseLtMatmulPlan_t* plan, size_t* workspaceSize); cusparseStatus_t CUSPARSELT_API cusparseLtMatmulPlanInit(const cusparseLtHandle_t* handle, cusparseLtMatmulPlan_t* plan, const cusparseLtMatmulDescriptor_t* matmulDescr, const cusparseLtMatmulAlgSelection_t* algSelection); cusparseStatus_t CUSPARSELT_API cusparseLtMatmulPlanDestroy(const cusparseLtMatmulPlan_t* plan); //############################################################################## //# MATMUL EXECUTION //############################################################################## cusparseStatus_t CUSPARSELT_API cusparseLtMatmul(const cusparseLtHandle_t* handle, const cusparseLtMatmulPlan_t* plan, const void* alpha, const void* d_A, const void* d_B, const void* beta, const void* d_C, void* d_D, void* workspace, cudaStream_t* streams, int32_t numStreams); cusparseStatus_t CUSPARSELT_API cusparseLtMatmulSearch(const cusparseLtHandle_t* handle, cusparseLtMatmulPlan_t* plan, const void* alpha, const void* d_A, const void* d_B, const void* beta, const void* d_C, void* d_D, void* workspace, // void* device_buf, cudaStream_t* streams, int32_t numStreams); //############################################################################## //# HELPER ROUTINES //############################################################################## // PRUNING typedef enum { CUSPARSELT_PRUNE_SPMMA_TILE = 0, CUSPARSELT_PRUNE_SPMMA_STRIP = 1 } cusparseLtPruneAlg_t; cusparseStatus_t CUSPARSELT_API cusparseLtSpMMAPrune(const cusparseLtHandle_t* handle, const cusparseLtMatmulDescriptor_t* matmulDescr, const void* d_in, void* d_out, cusparseLtPruneAlg_t pruneAlg, cudaStream_t stream); cusparseStatus_t CUSPARSELT_API cusparseLtSpMMAPruneCheck(const cusparseLtHandle_t* handle, const cusparseLtMatmulDescriptor_t* matmulDescr, const void* d_in, int* valid, cudaStream_t stream); cusparseStatus_t CUSPARSELT_API cusparseLtSpMMAPrune2(const cusparseLtHandle_t* handle, const cusparseLtMatDescriptor_t* sparseMatDescr, int isSparseA, cusparseOperation_t op, const void* d_in, void* d_out, cusparseLtPruneAlg_t pruneAlg, cudaStream_t stream); cusparseStatus_t CUSPARSELT_API cusparseLtSpMMAPruneCheck2(const cusparseLtHandle_t* handle, const cusparseLtMatDescriptor_t* sparseMatDescr, int isSparseA, cusparseOperation_t op, const void* d_in, int* d_valid, cudaStream_t stream); //------------------------------------------------------------------------------ // COMPRESSION cusparseStatus_t CUSPARSELT_API cusparseLtSpMMACompressedSize( const cusparseLtHandle_t* handle, const cusparseLtMatmulPlan_t* plan, size_t* compressedSize, size_t* compressedBufferSize); cusparseStatus_t CUSPARSELT_API cusparseLtSpMMACompress(const cusparseLtHandle_t* handle, const cusparseLtMatmulPlan_t* plan, const void* d_dense, void* d_compressed, void* d_compressed_buffer, cudaStream_t stream); cusparseStatus_t CUSPARSELT_API cusparseLtSpMMACompressedSize2( const cusparseLtHandle_t* handle, const cusparseLtMatDescriptor_t* sparseMatDescr, size_t* compressedSize, size_t* compressedBufferSize); cusparseStatus_t CUSPARSELT_API cusparseLtSpMMACompress2(const cusparseLtHandle_t* handle, const cusparseLtMatDescriptor_t* sparseMatDescr, int isSparseA, cusparseOperation_t op, const void* d_dense, void* d_compressed, void* d_compressed_buffer, cudaStream_t stream); //============================================================================== //============================================================================== #if defined(__cplusplus) } #endif // defined(__cplusplus) #endif // !defined(CUSPARSELT_HEADER_)