/* * Copyright 1993-2022 NVIDIA Corporation. All rights reserved. * * NOTICE TO LICENSEE: * * This source code and/or documentation ("Licensed Deliverables") are * subject to NVIDIA intellectual property rights under U.S. and * international Copyright laws. * * These Licensed Deliverables contained herein is PROPRIETARY and * CONFIDENTIAL to NVIDIA and is being provided under the terms and * conditions of a form of NVIDIA software license agreement by and * between NVIDIA and Licensee ("License Agreement") or electronically * accepted by Licensee. Notwithstanding any terms or conditions to * the contrary in the License Agreement, reproduction or disclosure * of the Licensed Deliverables to any third party without the express * written consent of NVIDIA is prohibited. * * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE * OF THESE LICENSED DELIVERABLES. * * U.S. Government End Users. These Licensed Deliverables are a * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT * 1995), consisting of "commercial computer software" and "commercial * computer software documentation" as such terms are used in 48 * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. 
Government * only as a commercial end item. Consistent with 48 C.F.R.12.212 and * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all * U.S. Government End Users acquire the Licensed Deliverables with * only those rights set forth herein. * * Any use of the Licensed Deliverables in individual and commercial * software must include, in the user documentation and internal * comments to the code, the above Disclaimer and U.S. Government End * Users Notice. */ #pragma once #ifndef CUBLASAPI #ifdef __CUDACC__ #define CUBLASAPI __host__ __device__ #else #define CUBLASAPI #endif #endif #include #include #include #include #if defined(__cplusplus) extern "C" { #endif /* __cplusplus */ /** Opaque structure holding CUBLASLT context */ typedef struct cublasLtContext* cublasLtHandle_t; cublasStatus_t CUBLASWINAPI cublasLtCreate(cublasLtHandle_t* lightHandle); cublasStatus_t CUBLASWINAPI cublasLtDestroy(cublasLtHandle_t lightHandle); const char* CUBLASWINAPI cublasLtGetStatusName(cublasStatus_t status); const char* CUBLASWINAPI cublasLtGetStatusString(cublasStatus_t status); size_t CUBLASWINAPI cublasLtGetVersion(void); size_t CUBLASWINAPI cublasLtGetCudartVersion(void); cublasStatus_t CUBLASWINAPI cublasLtGetProperty(libraryPropertyType type, int* value); cublasStatus_t CUBLASWINAPI cublasLtHeuristicsCacheGetCapacity(size_t* capacity); cublasStatus_t CUBLASWINAPI cublasLtHeuristicsCacheSetCapacity(size_t capacity); /** Restricts usage of CPU instructions (ISA) specified by the flags in the mask. * * Flags can be combined with bitwise OR(|) operator. Supported flags: * - 0x1 -- x86-64 AVX512 ISA * * Default mask: 0 (any applicable ISA is allowed). * * The function returns the previous value of the mask. * The function takes precedence over the environment variable CUBLASLT_DISABLE_CPU_INSTRUCTIONS_MASK. 
*/ unsigned CUBLASWINAPI cublasLtDisableCpuInstructionsSetMask(unsigned mask); /** Semi-opaque descriptor for matrix memory layout */ typedef struct { uint64_t data[8]; } cublasLtMatrixLayoutOpaque_t; /** Opaque descriptor for matrix memory layout */ typedef cublasLtMatrixLayoutOpaque_t* cublasLtMatrixLayout_t; /** Semi-opaque algorithm descriptor (to avoid complicated alloc/free schemes) * * This structure can be trivially serialized and later restored for use with the same version of cuBLAS library to save * on selecting the right configuration again. */ typedef struct { uint64_t data[8]; } cublasLtMatmulAlgo_t; /** Semi-opaque descriptor for cublasLtMatmul() operation details */ typedef struct { uint64_t data[32]; } cublasLtMatmulDescOpaque_t; /** Opaque descriptor for cublasLtMatmul() operation details */ typedef cublasLtMatmulDescOpaque_t* cublasLtMatmulDesc_t; /** Semi-opaque descriptor for cublasLtMatrixTransform() operation details */ typedef struct { uint64_t data[8]; } cublasLtMatrixTransformDescOpaque_t; /** Opaque descriptor for cublasLtMatrixTransform() operation details */ typedef cublasLtMatrixTransformDescOpaque_t* cublasLtMatrixTransformDesc_t; /** Semi-opaque descriptor for cublasLtMatmulPreference() operation details */ typedef struct { uint64_t data[8]; } cublasLtMatmulPreferenceOpaque_t; /** Opaque descriptor for cublasLtMatmulAlgoGetHeuristic() configuration */ typedef cublasLtMatmulPreferenceOpaque_t* cublasLtMatmulPreference_t; /** Tile size (in C/D matrix Rows x Cols) * * General order of tile IDs is sorted by size first and by first dimension second. 
*/ /* NOTE(review): these enumerator values are serialized inside cublasLtMatmulAlgo_t; never renumber or reuse existing entries -- only append before the _END sentinel. */ typedef enum { CUBLASLT_MATMUL_TILE_UNDEFINED = 0, CUBLASLT_MATMUL_TILE_8x8 = 1, CUBLASLT_MATMUL_TILE_8x16 = 2, CUBLASLT_MATMUL_TILE_16x8 = 3, CUBLASLT_MATMUL_TILE_8x32 = 4, CUBLASLT_MATMUL_TILE_16x16 = 5, CUBLASLT_MATMUL_TILE_32x8 = 6, CUBLASLT_MATMUL_TILE_8x64 = 7, CUBLASLT_MATMUL_TILE_16x32 = 8, CUBLASLT_MATMUL_TILE_32x16 = 9, CUBLASLT_MATMUL_TILE_64x8 = 10, CUBLASLT_MATMUL_TILE_32x32 = 11, CUBLASLT_MATMUL_TILE_32x64 = 12, CUBLASLT_MATMUL_TILE_64x32 = 13, CUBLASLT_MATMUL_TILE_32x128 = 14, CUBLASLT_MATMUL_TILE_64x64 = 15, CUBLASLT_MATMUL_TILE_128x32 = 16, CUBLASLT_MATMUL_TILE_64x128 = 17, CUBLASLT_MATMUL_TILE_128x64 = 18, CUBLASLT_MATMUL_TILE_64x256 = 19, CUBLASLT_MATMUL_TILE_128x128 = 20, CUBLASLT_MATMUL_TILE_256x64 = 21, CUBLASLT_MATMUL_TILE_64x512 = 22, CUBLASLT_MATMUL_TILE_128x256 = 23, CUBLASLT_MATMUL_TILE_256x128 = 24, CUBLASLT_MATMUL_TILE_512x64 = 25, CUBLASLT_MATMUL_TILE_64x96 = 26, CUBLASLT_MATMUL_TILE_96x64 = 27, CUBLASLT_MATMUL_TILE_96x128 = 28, CUBLASLT_MATMUL_TILE_128x160 = 29, CUBLASLT_MATMUL_TILE_160x128 = 30, CUBLASLT_MATMUL_TILE_192x128 = 31, CUBLASLT_MATMUL_TILE_128x192 = 32, CUBLASLT_MATMUL_TILE_128x96 = 33, CUBLASLT_MATMUL_TILE_32x256 = 34, CUBLASLT_MATMUL_TILE_256x32 = 35, CUBLASLT_MATMUL_TILE_8x128 = 36, CUBLASLT_MATMUL_TILE_8x192 = 37, CUBLASLT_MATMUL_TILE_8x256 = 38, CUBLASLT_MATMUL_TILE_8x320 = 39, CUBLASLT_MATMUL_TILE_8x384 = 40, CUBLASLT_MATMUL_TILE_8x448 = 41, CUBLASLT_MATMUL_TILE_8x512 = 42, CUBLASLT_MATMUL_TILE_8x576 = 43, CUBLASLT_MATMUL_TILE_8x640 = 44, CUBLASLT_MATMUL_TILE_8x704 = 45, CUBLASLT_MATMUL_TILE_8x768 = 46, CUBLASLT_MATMUL_TILE_16x64 = 47, CUBLASLT_MATMUL_TILE_16x128 = 48, CUBLASLT_MATMUL_TILE_16x192 = 49, CUBLASLT_MATMUL_TILE_16x256 = 50, CUBLASLT_MATMUL_TILE_16x320 = 51, CUBLASLT_MATMUL_TILE_16x384 = 52, CUBLASLT_MATMUL_TILE_16x448 = 53, CUBLASLT_MATMUL_TILE_16x512 = 54, CUBLASLT_MATMUL_TILE_16x576 = 55, CUBLASLT_MATMUL_TILE_16x640 = 56, CUBLASLT_MATMUL_TILE_16x704 = 57, CUBLASLT_MATMUL_TILE_16x768 = 58,
CUBLASLT_MATMUL_TILE_24x64 = 59, CUBLASLT_MATMUL_TILE_24x128 = 60, CUBLASLT_MATMUL_TILE_24x192 = 61, CUBLASLT_MATMUL_TILE_24x256 = 62, CUBLASLT_MATMUL_TILE_24x320 = 63, CUBLASLT_MATMUL_TILE_24x384 = 64, CUBLASLT_MATMUL_TILE_24x448 = 65, CUBLASLT_MATMUL_TILE_24x512 = 66, CUBLASLT_MATMUL_TILE_24x576 = 67, CUBLASLT_MATMUL_TILE_24x640 = 68, CUBLASLT_MATMUL_TILE_24x704 = 69, CUBLASLT_MATMUL_TILE_24x768 = 70, CUBLASLT_MATMUL_TILE_32x192 = 71, CUBLASLT_MATMUL_TILE_32x320 = 72, CUBLASLT_MATMUL_TILE_32x384 = 73, CUBLASLT_MATMUL_TILE_32x448 = 74, CUBLASLT_MATMUL_TILE_32x512 = 75, CUBLASLT_MATMUL_TILE_32x576 = 76, CUBLASLT_MATMUL_TILE_32x640 = 77, CUBLASLT_MATMUL_TILE_32x704 = 78, CUBLASLT_MATMUL_TILE_32x768 = 79, CUBLASLT_MATMUL_TILE_40x64 = 80, CUBLASLT_MATMUL_TILE_40x128 = 81, CUBLASLT_MATMUL_TILE_40x192 = 82, CUBLASLT_MATMUL_TILE_40x256 = 83, CUBLASLT_MATMUL_TILE_40x320 = 84, CUBLASLT_MATMUL_TILE_40x384 = 85, CUBLASLT_MATMUL_TILE_40x448 = 86, CUBLASLT_MATMUL_TILE_40x512 = 87, CUBLASLT_MATMUL_TILE_40x576 = 88, CUBLASLT_MATMUL_TILE_40x640 = 89, CUBLASLT_MATMUL_TILE_40x704 = 90, CUBLASLT_MATMUL_TILE_40x768 = 91, CUBLASLT_MATMUL_TILE_48x64 = 92, CUBLASLT_MATMUL_TILE_48x128 = 93, CUBLASLT_MATMUL_TILE_48x192 = 94, CUBLASLT_MATMUL_TILE_48x256 = 95, CUBLASLT_MATMUL_TILE_48x320 = 96, CUBLASLT_MATMUL_TILE_48x384 = 97, CUBLASLT_MATMUL_TILE_48x448 = 98, CUBLASLT_MATMUL_TILE_48x512 = 99, CUBLASLT_MATMUL_TILE_48x576 = 100, CUBLASLT_MATMUL_TILE_48x640 = 101, CUBLASLT_MATMUL_TILE_48x704 = 102, CUBLASLT_MATMUL_TILE_48x768 = 103, CUBLASLT_MATMUL_TILE_56x64 = 104, CUBLASLT_MATMUL_TILE_56x128 = 105, CUBLASLT_MATMUL_TILE_56x192 = 106, CUBLASLT_MATMUL_TILE_56x256 = 107, CUBLASLT_MATMUL_TILE_56x320 = 108, CUBLASLT_MATMUL_TILE_56x384 = 109, CUBLASLT_MATMUL_TILE_56x448 = 110, CUBLASLT_MATMUL_TILE_56x512 = 111, CUBLASLT_MATMUL_TILE_56x576 = 112, CUBLASLT_MATMUL_TILE_56x640 = 113, CUBLASLT_MATMUL_TILE_56x704 = 114, CUBLASLT_MATMUL_TILE_56x768 = 115, CUBLASLT_MATMUL_TILE_64x192 = 116,
CUBLASLT_MATMUL_TILE_64x320 = 117, CUBLASLT_MATMUL_TILE_64x384 = 118, CUBLASLT_MATMUL_TILE_64x448 = 119, CUBLASLT_MATMUL_TILE_64x576 = 120, CUBLASLT_MATMUL_TILE_64x640 = 121, CUBLASLT_MATMUL_TILE_64x704 = 122, CUBLASLT_MATMUL_TILE_64x768 = 123, CUBLASLT_MATMUL_TILE_72x64 = 124, CUBLASLT_MATMUL_TILE_72x128 = 125, CUBLASLT_MATMUL_TILE_72x192 = 126, CUBLASLT_MATMUL_TILE_72x256 = 127, CUBLASLT_MATMUL_TILE_72x320 = 128, CUBLASLT_MATMUL_TILE_72x384 = 129, CUBLASLT_MATMUL_TILE_72x448 = 130, CUBLASLT_MATMUL_TILE_72x512 = 131, CUBLASLT_MATMUL_TILE_72x576 = 132, CUBLASLT_MATMUL_TILE_72x640 = 133, CUBLASLT_MATMUL_TILE_80x64 = 134, CUBLASLT_MATMUL_TILE_80x128 = 135, CUBLASLT_MATMUL_TILE_80x192 = 136, CUBLASLT_MATMUL_TILE_80x256 = 137, CUBLASLT_MATMUL_TILE_80x320 = 138, CUBLASLT_MATMUL_TILE_80x384 = 139, CUBLASLT_MATMUL_TILE_80x448 = 140, CUBLASLT_MATMUL_TILE_80x512 = 141, CUBLASLT_MATMUL_TILE_80x576 = 142, CUBLASLT_MATMUL_TILE_88x64 = 143, CUBLASLT_MATMUL_TILE_88x128 = 144, CUBLASLT_MATMUL_TILE_88x192 = 145, CUBLASLT_MATMUL_TILE_88x256 = 146, CUBLASLT_MATMUL_TILE_88x320 = 147, CUBLASLT_MATMUL_TILE_88x384 = 148, CUBLASLT_MATMUL_TILE_88x448 = 149, CUBLASLT_MATMUL_TILE_88x512 = 150, CUBLASLT_MATMUL_TILE_96x192 = 151, CUBLASLT_MATMUL_TILE_96x256 = 152, CUBLASLT_MATMUL_TILE_96x320 = 153, CUBLASLT_MATMUL_TILE_96x384 = 154, CUBLASLT_MATMUL_TILE_96x448 = 155, CUBLASLT_MATMUL_TILE_96x512 = 156, CUBLASLT_MATMUL_TILE_104x64 = 157, CUBLASLT_MATMUL_TILE_104x128 = 158, CUBLASLT_MATMUL_TILE_104x192 = 159, CUBLASLT_MATMUL_TILE_104x256 = 160, CUBLASLT_MATMUL_TILE_104x320 = 161, CUBLASLT_MATMUL_TILE_104x384 = 162, CUBLASLT_MATMUL_TILE_104x448 = 163, CUBLASLT_MATMUL_TILE_112x64 = 164, CUBLASLT_MATMUL_TILE_112x128 = 165, CUBLASLT_MATMUL_TILE_112x192 = 166, CUBLASLT_MATMUL_TILE_112x256 = 167, CUBLASLT_MATMUL_TILE_112x320 = 168, CUBLASLT_MATMUL_TILE_112x384 = 169, CUBLASLT_MATMUL_TILE_120x64 = 170, CUBLASLT_MATMUL_TILE_120x128 = 171, CUBLASLT_MATMUL_TILE_120x192 = 172, CUBLASLT_MATMUL_TILE_120x256
= 173, CUBLASLT_MATMUL_TILE_120x320 = 174, CUBLASLT_MATMUL_TILE_120x384 = 175, CUBLASLT_MATMUL_TILE_128x320 = 176, CUBLASLT_MATMUL_TILE_128x384 = 177, CUBLASLT_MATMUL_TILE_136x64 = 178, CUBLASLT_MATMUL_TILE_136x128 = 179, CUBLASLT_MATMUL_TILE_136x192 = 180, CUBLASLT_MATMUL_TILE_136x256 = 181, CUBLASLT_MATMUL_TILE_136x320 = 182, CUBLASLT_MATMUL_TILE_144x64 = 183, CUBLASLT_MATMUL_TILE_144x128 = 184, CUBLASLT_MATMUL_TILE_144x192 = 185, CUBLASLT_MATMUL_TILE_144x256 = 186, CUBLASLT_MATMUL_TILE_144x320 = 187, CUBLASLT_MATMUL_TILE_152x64 = 188, CUBLASLT_MATMUL_TILE_152x128 = 189, CUBLASLT_MATMUL_TILE_152x192 = 190, CUBLASLT_MATMUL_TILE_152x256 = 191, CUBLASLT_MATMUL_TILE_152x320 = 192, CUBLASLT_MATMUL_TILE_160x64 = 193, CUBLASLT_MATMUL_TILE_160x192 = 194, CUBLASLT_MATMUL_TILE_160x256 = 195, CUBLASLT_MATMUL_TILE_168x64 = 196, CUBLASLT_MATMUL_TILE_168x128 = 197, CUBLASLT_MATMUL_TILE_168x192 = 198, CUBLASLT_MATMUL_TILE_168x256 = 199, CUBLASLT_MATMUL_TILE_176x64 = 200, CUBLASLT_MATMUL_TILE_176x128 = 201, CUBLASLT_MATMUL_TILE_176x192 = 202, CUBLASLT_MATMUL_TILE_176x256 = 203, CUBLASLT_MATMUL_TILE_184x64 = 204, CUBLASLT_MATMUL_TILE_184x128 = 205, CUBLASLT_MATMUL_TILE_184x192 = 206, CUBLASLT_MATMUL_TILE_184x256 = 207, CUBLASLT_MATMUL_TILE_192x64 = 208, CUBLASLT_MATMUL_TILE_192x192 = 209, CUBLASLT_MATMUL_TILE_192x256 = 210, CUBLASLT_MATMUL_TILE_200x64 = 211, CUBLASLT_MATMUL_TILE_200x128 = 212, CUBLASLT_MATMUL_TILE_200x192 = 213, CUBLASLT_MATMUL_TILE_208x64 = 214, CUBLASLT_MATMUL_TILE_208x128 = 215, CUBLASLT_MATMUL_TILE_208x192 = 216, CUBLASLT_MATMUL_TILE_216x64 = 217, CUBLASLT_MATMUL_TILE_216x128 = 218, CUBLASLT_MATMUL_TILE_216x192 = 219, CUBLASLT_MATMUL_TILE_224x64 = 220, CUBLASLT_MATMUL_TILE_224x128 = 221, CUBLASLT_MATMUL_TILE_224x192 = 222, CUBLASLT_MATMUL_TILE_232x64 = 223, CUBLASLT_MATMUL_TILE_232x128 = 224, CUBLASLT_MATMUL_TILE_232x192 = 225, CUBLASLT_MATMUL_TILE_240x64 = 226, CUBLASLT_MATMUL_TILE_240x128 = 227, CUBLASLT_MATMUL_TILE_240x192 = 228,
CUBLASLT_MATMUL_TILE_248x64 = 229, CUBLASLT_MATMUL_TILE_248x128 = 230, CUBLASLT_MATMUL_TILE_248x192 = 231, CUBLASLT_MATMUL_TILE_256x192 = 232, CUBLASLT_MATMUL_TILE_264x64 = 233, CUBLASLT_MATMUL_TILE_264x128 = 234, CUBLASLT_MATMUL_TILE_272x64 = 235, CUBLASLT_MATMUL_TILE_272x128 = 236, CUBLASLT_MATMUL_TILE_280x64 = 237, CUBLASLT_MATMUL_TILE_280x128 = 238, CUBLASLT_MATMUL_TILE_288x64 = 239, CUBLASLT_MATMUL_TILE_288x128 = 240, CUBLASLT_MATMUL_TILE_296x64 = 241, CUBLASLT_MATMUL_TILE_296x128 = 242, CUBLASLT_MATMUL_TILE_304x64 = 243, CUBLASLT_MATMUL_TILE_304x128 = 244, CUBLASLT_MATMUL_TILE_312x64 = 245, CUBLASLT_MATMUL_TILE_312x128 = 246, CUBLASLT_MATMUL_TILE_320x64 = 247, CUBLASLT_MATMUL_TILE_320x128 = 248, CUBLASLT_MATMUL_TILE_328x64 = 249, CUBLASLT_MATMUL_TILE_328x128 = 250, CUBLASLT_MATMUL_TILE_336x64 = 251, CUBLASLT_MATMUL_TILE_336x128 = 252, CUBLASLT_MATMUL_TILE_344x64 = 253, CUBLASLT_MATMUL_TILE_344x128 = 254, CUBLASLT_MATMUL_TILE_352x64 = 255, CUBLASLT_MATMUL_TILE_352x128 = 256, CUBLASLT_MATMUL_TILE_360x64 = 257, CUBLASLT_MATMUL_TILE_360x128 = 258, CUBLASLT_MATMUL_TILE_368x64 = 259, CUBLASLT_MATMUL_TILE_368x128 = 260, CUBLASLT_MATMUL_TILE_376x64 = 261, CUBLASLT_MATMUL_TILE_376x128 = 262, CUBLASLT_MATMUL_TILE_384x64 = 263, CUBLASLT_MATMUL_TILE_384x128 = 264, CUBLASLT_MATMUL_TILE_392x64 = 265, CUBLASLT_MATMUL_TILE_400x64 = 266, CUBLASLT_MATMUL_TILE_408x64 = 267, CUBLASLT_MATMUL_TILE_416x64 = 268, CUBLASLT_MATMUL_TILE_424x64 = 269, CUBLASLT_MATMUL_TILE_432x64 = 270, CUBLASLT_MATMUL_TILE_440x64 = 271, CUBLASLT_MATMUL_TILE_448x64 = 272, CUBLASLT_MATMUL_TILE_456x64 = 273, CUBLASLT_MATMUL_TILE_464x64 = 274, CUBLASLT_MATMUL_TILE_472x64 = 275, CUBLASLT_MATMUL_TILE_480x64 = 276, CUBLASLT_MATMUL_TILE_488x64 = 277, CUBLASLT_MATMUL_TILE_496x64 = 278, CUBLASLT_MATMUL_TILE_504x64 = 279, CUBLASLT_MATMUL_TILE_520x64 = 280, CUBLASLT_MATMUL_TILE_528x64 = 281, CUBLASLT_MATMUL_TILE_536x64 = 282, CUBLASLT_MATMUL_TILE_544x64 = 283, CUBLASLT_MATMUL_TILE_552x64 = 284,
CUBLASLT_MATMUL_TILE_560x64 = 285, CUBLASLT_MATMUL_TILE_568x64 = 286, CUBLASLT_MATMUL_TILE_576x64 = 287, CUBLASLT_MATMUL_TILE_584x64 = 288, CUBLASLT_MATMUL_TILE_592x64 = 289, CUBLASLT_MATMUL_TILE_600x64 = 290, CUBLASLT_MATMUL_TILE_608x64 = 291, CUBLASLT_MATMUL_TILE_616x64 = 292, CUBLASLT_MATMUL_TILE_624x64 = 293, CUBLASLT_MATMUL_TILE_632x64 = 294, CUBLASLT_MATMUL_TILE_640x64 = 295, CUBLASLT_MATMUL_TILE_648x64 = 296, CUBLASLT_MATMUL_TILE_656x64 = 297, CUBLASLT_MATMUL_TILE_664x64 = 298, CUBLASLT_MATMUL_TILE_672x64 = 299, CUBLASLT_MATMUL_TILE_680x64 = 300, CUBLASLT_MATMUL_TILE_688x64 = 301, CUBLASLT_MATMUL_TILE_696x64 = 302, CUBLASLT_MATMUL_TILE_704x64 = 303, CUBLASLT_MATMUL_TILE_712x64 = 304, CUBLASLT_MATMUL_TILE_720x64 = 305, CUBLASLT_MATMUL_TILE_728x64 = 306, CUBLASLT_MATMUL_TILE_736x64 = 307, CUBLASLT_MATMUL_TILE_744x64 = 308, CUBLASLT_MATMUL_TILE_752x64 = 309, CUBLASLT_MATMUL_TILE_760x64 = 310, CUBLASLT_MATMUL_TILE_768x64 = 311, CUBLASLT_MATMUL_TILE_64x16 = 312, CUBLASLT_MATMUL_TILE_64x24 = 313, CUBLASLT_MATMUL_TILE_64x40 = 314, CUBLASLT_MATMUL_TILE_64x48 = 315, CUBLASLT_MATMUL_TILE_64x56 = 316, CUBLASLT_MATMUL_TILE_64x72 = 317, CUBLASLT_MATMUL_TILE_64x80 = 318, CUBLASLT_MATMUL_TILE_64x88 = 319, CUBLASLT_MATMUL_TILE_64x104 = 320, CUBLASLT_MATMUL_TILE_64x112 = 321, CUBLASLT_MATMUL_TILE_64x120 = 322, CUBLASLT_MATMUL_TILE_64x136 = 323, CUBLASLT_MATMUL_TILE_64x144 = 324, CUBLASLT_MATMUL_TILE_64x152 = 325, CUBLASLT_MATMUL_TILE_64x160 = 326, CUBLASLT_MATMUL_TILE_64x168 = 327, CUBLASLT_MATMUL_TILE_64x176 = 328, CUBLASLT_MATMUL_TILE_64x184 = 329, CUBLASLT_MATMUL_TILE_64x200 = 330, CUBLASLT_MATMUL_TILE_64x208 = 331, CUBLASLT_MATMUL_TILE_64x216 = 332, CUBLASLT_MATMUL_TILE_64x224 = 333, CUBLASLT_MATMUL_TILE_64x232 = 334, CUBLASLT_MATMUL_TILE_64x240 = 335, CUBLASLT_MATMUL_TILE_64x248 = 336, CUBLASLT_MATMUL_TILE_64x264 = 337, CUBLASLT_MATMUL_TILE_64x272 = 338, CUBLASLT_MATMUL_TILE_64x280 = 339, CUBLASLT_MATMUL_TILE_64x288 = 340, CUBLASLT_MATMUL_TILE_64x296 = 341,
CUBLASLT_MATMUL_TILE_64x304 = 342, CUBLASLT_MATMUL_TILE_64x312 = 343, CUBLASLT_MATMUL_TILE_64x328 = 344, CUBLASLT_MATMUL_TILE_64x336 = 345, CUBLASLT_MATMUL_TILE_64x344 = 346, CUBLASLT_MATMUL_TILE_64x352 = 347, CUBLASLT_MATMUL_TILE_64x360 = 348, CUBLASLT_MATMUL_TILE_64x368 = 349, CUBLASLT_MATMUL_TILE_64x376 = 350, CUBLASLT_MATMUL_TILE_64x392 = 351, CUBLASLT_MATMUL_TILE_64x400 = 352, CUBLASLT_MATMUL_TILE_64x408 = 353, CUBLASLT_MATMUL_TILE_64x416 = 354, CUBLASLT_MATMUL_TILE_64x424 = 355, CUBLASLT_MATMUL_TILE_64x432 = 356, CUBLASLT_MATMUL_TILE_64x440 = 357, CUBLASLT_MATMUL_TILE_64x456 = 358, CUBLASLT_MATMUL_TILE_64x464 = 359, CUBLASLT_MATMUL_TILE_64x472 = 360, CUBLASLT_MATMUL_TILE_64x480 = 361, CUBLASLT_MATMUL_TILE_64x488 = 362, CUBLASLT_MATMUL_TILE_64x496 = 363, CUBLASLT_MATMUL_TILE_64x504 = 364, CUBLASLT_MATMUL_TILE_64x520 = 365, CUBLASLT_MATMUL_TILE_64x528 = 366, CUBLASLT_MATMUL_TILE_64x536 = 367, CUBLASLT_MATMUL_TILE_64x544 = 368, CUBLASLT_MATMUL_TILE_64x552 = 369, CUBLASLT_MATMUL_TILE_64x560 = 370, CUBLASLT_MATMUL_TILE_64x568 = 371, CUBLASLT_MATMUL_TILE_64x584 = 372, CUBLASLT_MATMUL_TILE_64x592 = 373, CUBLASLT_MATMUL_TILE_64x600 = 374, CUBLASLT_MATMUL_TILE_64x608 = 375, CUBLASLT_MATMUL_TILE_64x616 = 376, CUBLASLT_MATMUL_TILE_64x624 = 377, CUBLASLT_MATMUL_TILE_64x632 = 378, CUBLASLT_MATMUL_TILE_64x648 = 379, CUBLASLT_MATMUL_TILE_64x656 = 380, CUBLASLT_MATMUL_TILE_64x664 = 381, CUBLASLT_MATMUL_TILE_64x672 = 382, CUBLASLT_MATMUL_TILE_64x680 = 383, CUBLASLT_MATMUL_TILE_64x688 = 384, CUBLASLT_MATMUL_TILE_64x696 = 385, CUBLASLT_MATMUL_TILE_64x712 = 386, CUBLASLT_MATMUL_TILE_64x720 = 387, CUBLASLT_MATMUL_TILE_64x728 = 388, CUBLASLT_MATMUL_TILE_64x736 = 389, CUBLASLT_MATMUL_TILE_64x744 = 390, CUBLASLT_MATMUL_TILE_64x752 = 391, CUBLASLT_MATMUL_TILE_64x760 = 392, CUBLASLT_MATMUL_TILE_128x8 = 393, CUBLASLT_MATMUL_TILE_128x16 = 394, CUBLASLT_MATMUL_TILE_128x24 = 395, CUBLASLT_MATMUL_TILE_128x40 = 396, CUBLASLT_MATMUL_TILE_128x48 = 397, CUBLASLT_MATMUL_TILE_128x56 = 398,
CUBLASLT_MATMUL_TILE_128x72 = 399, CUBLASLT_MATMUL_TILE_128x80 = 400, CUBLASLT_MATMUL_TILE_128x88 = 401, CUBLASLT_MATMUL_TILE_128x104 = 402, CUBLASLT_MATMUL_TILE_128x112 = 403, CUBLASLT_MATMUL_TILE_128x120 = 404, CUBLASLT_MATMUL_TILE_128x136 = 405, CUBLASLT_MATMUL_TILE_128x144 = 406, CUBLASLT_MATMUL_TILE_128x152 = 407, CUBLASLT_MATMUL_TILE_128x168 = 408, CUBLASLT_MATMUL_TILE_128x176 = 409, CUBLASLT_MATMUL_TILE_128x184 = 410, CUBLASLT_MATMUL_TILE_128x200 = 411, CUBLASLT_MATMUL_TILE_128x208 = 412, CUBLASLT_MATMUL_TILE_128x216 = 413, CUBLASLT_MATMUL_TILE_128x224 = 414, CUBLASLT_MATMUL_TILE_128x232 = 415, CUBLASLT_MATMUL_TILE_128x240 = 416, CUBLASLT_MATMUL_TILE_128x248 = 417, CUBLASLT_MATMUL_TILE_128x264 = 418, CUBLASLT_MATMUL_TILE_128x272 = 419, CUBLASLT_MATMUL_TILE_128x280 = 420, CUBLASLT_MATMUL_TILE_128x288 = 421, CUBLASLT_MATMUL_TILE_128x296 = 422, CUBLASLT_MATMUL_TILE_128x304 = 423, CUBLASLT_MATMUL_TILE_128x312 = 424, CUBLASLT_MATMUL_TILE_128x328 = 425, CUBLASLT_MATMUL_TILE_128x336 = 426, CUBLASLT_MATMUL_TILE_128x344 = 427, CUBLASLT_MATMUL_TILE_128x352 = 428, CUBLASLT_MATMUL_TILE_128x360 = 429, CUBLASLT_MATMUL_TILE_128x368 = 430, CUBLASLT_MATMUL_TILE_128x376 = 431, CUBLASLT_MATMUL_TILE_128x392 = 432, CUBLASLT_MATMUL_TILE_128x400 = 433, CUBLASLT_MATMUL_TILE_128x408 = 434, CUBLASLT_MATMUL_TILE_128x416 = 435, CUBLASLT_MATMUL_TILE_128x424 = 436, CUBLASLT_MATMUL_TILE_128x432 = 437, CUBLASLT_MATMUL_TILE_128x440 = 438, CUBLASLT_MATMUL_TILE_128x448 = 439, CUBLASLT_MATMUL_TILE_128x456 = 440, CUBLASLT_MATMUL_TILE_128x464 = 441, CUBLASLT_MATMUL_TILE_128x472 = 442, CUBLASLT_MATMUL_TILE_128x480 = 443, CUBLASLT_MATMUL_TILE_128x488 = 444, CUBLASLT_MATMUL_TILE_128x496 = 445, CUBLASLT_MATMUL_TILE_128x504 = 446, CUBLASLT_MATMUL_TILE_128x512 = 447, CUBLASLT_MATMUL_TILE_192x8 = 448, CUBLASLT_MATMUL_TILE_192x16 = 449, CUBLASLT_MATMUL_TILE_192x24 = 450, CUBLASLT_MATMUL_TILE_192x32 = 451, CUBLASLT_MATMUL_TILE_192x40 = 452, CUBLASLT_MATMUL_TILE_192x48 = 453, CUBLASLT_MATMUL_TILE_192x56 =
454, CUBLASLT_MATMUL_TILE_192x72 = 455, CUBLASLT_MATMUL_TILE_192x80 = 456, CUBLASLT_MATMUL_TILE_192x88 = 457, CUBLASLT_MATMUL_TILE_192x96 = 458, CUBLASLT_MATMUL_TILE_192x104 = 459, CUBLASLT_MATMUL_TILE_192x112 = 460, CUBLASLT_MATMUL_TILE_192x120 = 461, CUBLASLT_MATMUL_TILE_192x136 = 462, CUBLASLT_MATMUL_TILE_192x144 = 463, CUBLASLT_MATMUL_TILE_192x152 = 464, CUBLASLT_MATMUL_TILE_192x160 = 465, CUBLASLT_MATMUL_TILE_192x168 = 466, CUBLASLT_MATMUL_TILE_192x176 = 467, CUBLASLT_MATMUL_TILE_192x184 = 468, CUBLASLT_MATMUL_TILE_192x200 = 469, CUBLASLT_MATMUL_TILE_192x208 = 470, CUBLASLT_MATMUL_TILE_192x216 = 471, CUBLASLT_MATMUL_TILE_192x224 = 472, CUBLASLT_MATMUL_TILE_192x232 = 473, CUBLASLT_MATMUL_TILE_192x240 = 474, CUBLASLT_MATMUL_TILE_192x248 = 475, CUBLASLT_MATMUL_TILE_192x264 = 476, CUBLASLT_MATMUL_TILE_192x272 = 477, CUBLASLT_MATMUL_TILE_192x280 = 478, CUBLASLT_MATMUL_TILE_192x288 = 479, CUBLASLT_MATMUL_TILE_192x296 = 480, CUBLASLT_MATMUL_TILE_192x304 = 481, CUBLASLT_MATMUL_TILE_192x312 = 482, CUBLASLT_MATMUL_TILE_192x320 = 483, CUBLASLT_MATMUL_TILE_192x328 = 484, CUBLASLT_MATMUL_TILE_192x336 = 485, CUBLASLT_MATMUL_TILE_256x8 = 486, CUBLASLT_MATMUL_TILE_256x16 = 487, CUBLASLT_MATMUL_TILE_256x24 = 488, CUBLASLT_MATMUL_TILE_256x40 = 489, CUBLASLT_MATMUL_TILE_256x48 = 490, CUBLASLT_MATMUL_TILE_256x56 = 491, CUBLASLT_MATMUL_TILE_256x72 = 492, CUBLASLT_MATMUL_TILE_256x80 = 493, CUBLASLT_MATMUL_TILE_256x88 = 494, CUBLASLT_MATMUL_TILE_256x96 = 495, CUBLASLT_MATMUL_TILE_256x104 = 496, CUBLASLT_MATMUL_TILE_256x112 = 497, CUBLASLT_MATMUL_TILE_256x120 = 498, CUBLASLT_MATMUL_TILE_256x136 = 499, CUBLASLT_MATMUL_TILE_256x144 = 500, CUBLASLT_MATMUL_TILE_256x152 = 501, CUBLASLT_MATMUL_TILE_256x160 = 502, CUBLASLT_MATMUL_TILE_256x168 = 503, CUBLASLT_MATMUL_TILE_256x176 = 504, CUBLASLT_MATMUL_TILE_256x184 = 505, CUBLASLT_MATMUL_TILE_256x200 = 506, CUBLASLT_MATMUL_TILE_256x208 = 507, CUBLASLT_MATMUL_TILE_256x216 = 508, CUBLASLT_MATMUL_TILE_256x224 = 509, CUBLASLT_MATMUL_TILE_256x232
= 510, CUBLASLT_MATMUL_TILE_256x240 = 511, CUBLASLT_MATMUL_TILE_256x248 = 512, CUBLASLT_MATMUL_TILE_256x256 = 513, CUBLASLT_MATMUL_TILE_320x8 = 514, CUBLASLT_MATMUL_TILE_320x16 = 515, CUBLASLT_MATMUL_TILE_320x24 = 516, CUBLASLT_MATMUL_TILE_320x32 = 517, CUBLASLT_MATMUL_TILE_320x40 = 518, CUBLASLT_MATMUL_TILE_320x48 = 519, CUBLASLT_MATMUL_TILE_320x56 = 520, CUBLASLT_MATMUL_TILE_320x72 = 521, CUBLASLT_MATMUL_TILE_320x80 = 522, CUBLASLT_MATMUL_TILE_320x88 = 523, CUBLASLT_MATMUL_TILE_320x96 = 524, CUBLASLT_MATMUL_TILE_320x104 = 525, CUBLASLT_MATMUL_TILE_320x112 = 526, CUBLASLT_MATMUL_TILE_320x120 = 527, CUBLASLT_MATMUL_TILE_320x136 = 528, CUBLASLT_MATMUL_TILE_320x144 = 529, CUBLASLT_MATMUL_TILE_320x152 = 530, CUBLASLT_MATMUL_TILE_320x160 = 531, CUBLASLT_MATMUL_TILE_320x168 = 532, CUBLASLT_MATMUL_TILE_320x176 = 533, CUBLASLT_MATMUL_TILE_320x184 = 534, CUBLASLT_MATMUL_TILE_320x192 = 535, CUBLASLT_MATMUL_TILE_320x200 = 536, CUBLASLT_MATMUL_TILE_384x8 = 537, CUBLASLT_MATMUL_TILE_384x16 = 538, CUBLASLT_MATMUL_TILE_384x24 = 539, CUBLASLT_MATMUL_TILE_384x32 = 540, CUBLASLT_MATMUL_TILE_384x40 = 541, CUBLASLT_MATMUL_TILE_384x48 = 542, CUBLASLT_MATMUL_TILE_384x56 = 543, CUBLASLT_MATMUL_TILE_384x72 = 544, CUBLASLT_MATMUL_TILE_384x80 = 545, CUBLASLT_MATMUL_TILE_384x88 = 546, CUBLASLT_MATMUL_TILE_384x96 = 547, CUBLASLT_MATMUL_TILE_384x104 = 548, CUBLASLT_MATMUL_TILE_384x112 = 549, CUBLASLT_MATMUL_TILE_384x120 = 550, CUBLASLT_MATMUL_TILE_384x136 = 551, CUBLASLT_MATMUL_TILE_384x144 = 552, CUBLASLT_MATMUL_TILE_384x152 = 553, CUBLASLT_MATMUL_TILE_384x160 = 554, CUBLASLT_MATMUL_TILE_384x168 = 555, CUBLASLT_MATMUL_TILE_448x8 = 556, CUBLASLT_MATMUL_TILE_448x16 = 557, CUBLASLT_MATMUL_TILE_448x24 = 558, CUBLASLT_MATMUL_TILE_448x32 = 559, CUBLASLT_MATMUL_TILE_448x40 = 560, CUBLASLT_MATMUL_TILE_448x48 = 561, CUBLASLT_MATMUL_TILE_448x56 = 562, CUBLASLT_MATMUL_TILE_448x72 = 563, CUBLASLT_MATMUL_TILE_448x80 = 564, CUBLASLT_MATMUL_TILE_448x88 = 565, CUBLASLT_MATMUL_TILE_448x96 = 566,
CUBLASLT_MATMUL_TILE_448x104 = 567, CUBLASLT_MATMUL_TILE_448x112 = 568, CUBLASLT_MATMUL_TILE_448x120 = 569, CUBLASLT_MATMUL_TILE_448x128 = 570, CUBLASLT_MATMUL_TILE_448x136 = 571, CUBLASLT_MATMUL_TILE_448x144 = 572, CUBLASLT_MATMUL_TILE_512x8 = 573, CUBLASLT_MATMUL_TILE_512x16 = 574, CUBLASLT_MATMUL_TILE_512x24 = 575, CUBLASLT_MATMUL_TILE_512x32 = 576, CUBLASLT_MATMUL_TILE_512x40 = 577, CUBLASLT_MATMUL_TILE_512x48 = 578, CUBLASLT_MATMUL_TILE_512x56 = 579, CUBLASLT_MATMUL_TILE_512x72 = 580, CUBLASLT_MATMUL_TILE_512x80 = 581, CUBLASLT_MATMUL_TILE_512x88 = 582, CUBLASLT_MATMUL_TILE_512x96 = 583, CUBLASLT_MATMUL_TILE_512x104 = 584, CUBLASLT_MATMUL_TILE_512x112 = 585, CUBLASLT_MATMUL_TILE_512x120 = 586, CUBLASLT_MATMUL_TILE_512x128 = 587, CUBLASLT_MATMUL_TILE_576x8 = 588, CUBLASLT_MATMUL_TILE_576x16 = 589, CUBLASLT_MATMUL_TILE_576x24 = 590, CUBLASLT_MATMUL_TILE_576x32 = 591, CUBLASLT_MATMUL_TILE_576x40 = 592, CUBLASLT_MATMUL_TILE_576x48 = 593, CUBLASLT_MATMUL_TILE_576x56 = 594, CUBLASLT_MATMUL_TILE_576x72 = 595, CUBLASLT_MATMUL_TILE_576x80 = 596, CUBLASLT_MATMUL_TILE_576x88 = 597, CUBLASLT_MATMUL_TILE_576x96 = 598, CUBLASLT_MATMUL_TILE_576x104 = 599, CUBLASLT_MATMUL_TILE_576x112 = 600, CUBLASLT_MATMUL_TILE_640x8 = 601, CUBLASLT_MATMUL_TILE_640x16 = 602, CUBLASLT_MATMUL_TILE_640x24 = 603, CUBLASLT_MATMUL_TILE_640x32 = 604, CUBLASLT_MATMUL_TILE_640x40 = 605, CUBLASLT_MATMUL_TILE_640x48 = 606, CUBLASLT_MATMUL_TILE_640x56 = 607, CUBLASLT_MATMUL_TILE_640x72 = 608, CUBLASLT_MATMUL_TILE_640x80 = 609, CUBLASLT_MATMUL_TILE_640x88 = 610, CUBLASLT_MATMUL_TILE_640x96 = 611, CUBLASLT_MATMUL_TILE_704x8 = 612, CUBLASLT_MATMUL_TILE_704x16 = 613, CUBLASLT_MATMUL_TILE_704x24 = 614, CUBLASLT_MATMUL_TILE_704x32 = 615, CUBLASLT_MATMUL_TILE_704x40 = 616, CUBLASLT_MATMUL_TILE_704x48 = 617, CUBLASLT_MATMUL_TILE_704x56 = 618, CUBLASLT_MATMUL_TILE_704x72 = 619, CUBLASLT_MATMUL_TILE_704x80 = 620, CUBLASLT_MATMUL_TILE_704x88 = 621, CUBLASLT_MATMUL_TILE_768x8 = 622, CUBLASLT_MATMUL_TILE_768x16 =
623, CUBLASLT_MATMUL_TILE_768x24 = 624, CUBLASLT_MATMUL_TILE_768x32 = 625, CUBLASLT_MATMUL_TILE_768x40 = 626, CUBLASLT_MATMUL_TILE_768x48 = 627, CUBLASLT_MATMUL_TILE_768x56 = 628, CUBLASLT_MATMUL_TILE_768x72 = 629, CUBLASLT_MATMUL_TILE_768x80 = 630, CUBLASLT_MATMUL_TILE_256x512 = 631, CUBLASLT_MATMUL_TILE_256x1024 = 632, CUBLASLT_MATMUL_TILE_512x512 = 633, CUBLASLT_MATMUL_TILE_512x1024 = 634, CUBLASLT_MATMUL_TILE_END } cublasLtMatmulTile_t; /** Size and number of stages in which elements are read into shared memory * * General order of stages IDs is sorted by stage size first and by number of stages second. */ /* NOTE(review): stage enum skips values 29-30; presumably retired upstream entries -- do not reuse them. */ typedef enum { CUBLASLT_MATMUL_STAGES_UNDEFINED = 0, CUBLASLT_MATMUL_STAGES_16x1 = 1, CUBLASLT_MATMUL_STAGES_16x2 = 2, CUBLASLT_MATMUL_STAGES_16x3 = 3, CUBLASLT_MATMUL_STAGES_16x4 = 4, CUBLASLT_MATMUL_STAGES_16x5 = 5, CUBLASLT_MATMUL_STAGES_16x6 = 6, CUBLASLT_MATMUL_STAGES_32x1 = 7, CUBLASLT_MATMUL_STAGES_32x2 = 8, CUBLASLT_MATMUL_STAGES_32x3 = 9, CUBLASLT_MATMUL_STAGES_32x4 = 10, CUBLASLT_MATMUL_STAGES_32x5 = 11, CUBLASLT_MATMUL_STAGES_32x6 = 12, CUBLASLT_MATMUL_STAGES_64x1 = 13, CUBLASLT_MATMUL_STAGES_64x2 = 14, CUBLASLT_MATMUL_STAGES_64x3 = 15, CUBLASLT_MATMUL_STAGES_64x4 = 16, CUBLASLT_MATMUL_STAGES_64x5 = 17, CUBLASLT_MATMUL_STAGES_64x6 = 18, CUBLASLT_MATMUL_STAGES_128x1 = 19, CUBLASLT_MATMUL_STAGES_128x2 = 20, CUBLASLT_MATMUL_STAGES_128x3 = 21, CUBLASLT_MATMUL_STAGES_128x4 = 22, CUBLASLT_MATMUL_STAGES_128x5 = 23, CUBLASLT_MATMUL_STAGES_128x6 = 24, CUBLASLT_MATMUL_STAGES_32x10 = 25, CUBLASLT_MATMUL_STAGES_8x4 = 26, CUBLASLT_MATMUL_STAGES_16x10 = 27, CUBLASLT_MATMUL_STAGES_8x5 = 28, CUBLASLT_MATMUL_STAGES_8x3 = 31, CUBLASLT_MATMUL_STAGES_8xAUTO = 32, CUBLASLT_MATMUL_STAGES_16xAUTO = 33, CUBLASLT_MATMUL_STAGES_32xAUTO = 34, CUBLASLT_MATMUL_STAGES_64xAUTO = 35, CUBLASLT_MATMUL_STAGES_128xAUTO = 36, CUBLASLT_MATMUL_STAGES_256xAUTO = 37, CUBLASLT_MATMUL_STAGES_END } cublasLtMatmulStages_t; /** Thread Block Cluster size * * Typically dimensioned similar to
cublasLtMatmulTile_t, with the third coordinate unused at this time. */ /* NOTE(review): cluster-shape values are part of the serialized algo configuration; value 1 is skipped upstream -- never renumber or reuse existing entries. */ typedef enum { /** Let library pick cluster shape automatically */ CUBLASLT_CLUSTER_SHAPE_AUTO = 0, CUBLASLT_CLUSTER_SHAPE_1x1x1 = 2, CUBLASLT_CLUSTER_SHAPE_2x1x1 = 3, CUBLASLT_CLUSTER_SHAPE_4x1x1 = 4, CUBLASLT_CLUSTER_SHAPE_1x2x1 = 5, CUBLASLT_CLUSTER_SHAPE_2x2x1 = 6, CUBLASLT_CLUSTER_SHAPE_4x2x1 = 7, CUBLASLT_CLUSTER_SHAPE_1x4x1 = 8, CUBLASLT_CLUSTER_SHAPE_2x4x1 = 9, CUBLASLT_CLUSTER_SHAPE_4x4x1 = 10, CUBLASLT_CLUSTER_SHAPE_8x1x1 = 11, CUBLASLT_CLUSTER_SHAPE_1x8x1 = 12, CUBLASLT_CLUSTER_SHAPE_8x2x1 = 13, CUBLASLT_CLUSTER_SHAPE_2x8x1 = 14, CUBLASLT_CLUSTER_SHAPE_16x1x1 = 15, CUBLASLT_CLUSTER_SHAPE_1x16x1 = 16, CUBLASLT_CLUSTER_SHAPE_3x1x1 = 17, CUBLASLT_CLUSTER_SHAPE_5x1x1 = 18, CUBLASLT_CLUSTER_SHAPE_6x1x1 = 19, CUBLASLT_CLUSTER_SHAPE_7x1x1 = 20, CUBLASLT_CLUSTER_SHAPE_9x1x1 = 21, CUBLASLT_CLUSTER_SHAPE_10x1x1 = 22, CUBLASLT_CLUSTER_SHAPE_11x1x1 = 23, CUBLASLT_CLUSTER_SHAPE_12x1x1 = 24, CUBLASLT_CLUSTER_SHAPE_13x1x1 = 25, CUBLASLT_CLUSTER_SHAPE_14x1x1 = 26, CUBLASLT_CLUSTER_SHAPE_15x1x1 = 27, CUBLASLT_CLUSTER_SHAPE_3x2x1 = 28, CUBLASLT_CLUSTER_SHAPE_5x2x1 = 29, CUBLASLT_CLUSTER_SHAPE_6x2x1 = 30, CUBLASLT_CLUSTER_SHAPE_7x2x1 = 31, CUBLASLT_CLUSTER_SHAPE_1x3x1 = 32, CUBLASLT_CLUSTER_SHAPE_2x3x1 = 33, CUBLASLT_CLUSTER_SHAPE_3x3x1 = 34, CUBLASLT_CLUSTER_SHAPE_4x3x1 = 35, CUBLASLT_CLUSTER_SHAPE_5x3x1 = 36, CUBLASLT_CLUSTER_SHAPE_3x4x1 = 37, CUBLASLT_CLUSTER_SHAPE_1x5x1 = 38, CUBLASLT_CLUSTER_SHAPE_2x5x1 = 39, CUBLASLT_CLUSTER_SHAPE_3x5x1 = 40, CUBLASLT_CLUSTER_SHAPE_1x6x1 = 41, CUBLASLT_CLUSTER_SHAPE_2x6x1 = 42, CUBLASLT_CLUSTER_SHAPE_1x7x1 = 43, CUBLASLT_CLUSTER_SHAPE_2x7x1 = 44, CUBLASLT_CLUSTER_SHAPE_1x9x1 = 45, CUBLASLT_CLUSTER_SHAPE_1x10x1 = 46, CUBLASLT_CLUSTER_SHAPE_1x11x1 = 47, CUBLASLT_CLUSTER_SHAPE_1x12x1 = 48, CUBLASLT_CLUSTER_SHAPE_1x13x1 = 49, CUBLASLT_CLUSTER_SHAPE_1x14x1 = 50, CUBLASLT_CLUSTER_SHAPE_1x15x1 = 51, CUBLASLT_CLUSTER_SHAPE_END } cublasLtClusterShape_t; /** Inner size of
the kernel * * Represents various aspects of internal kernel design, that don't impact CUDA grid size but may have other more subtle * effects. * */ /* NOTE(review): inner-shape, scale-mode and pointer-mode values below are serialized in descriptors; do not renumber existing entries. */ typedef enum { CUBLASLT_MATMUL_INNER_SHAPE_UNDEFINED = 0, CUBLASLT_MATMUL_INNER_SHAPE_MMA884 = 1, CUBLASLT_MATMUL_INNER_SHAPE_MMA1684 = 2, CUBLASLT_MATMUL_INNER_SHAPE_MMA1688 = 3, CUBLASLT_MATMUL_INNER_SHAPE_MMA16816 = 4, CUBLASLT_MATMUL_INNER_SHAPE_END } cublasLtMatmulInnerShape_t; /** Scaling mode for per-matrix scaling */ typedef enum { /** Scaling factors are single precision scalars applied to the whole tensor */ CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F = 0, /** Scaling factors are tensors that contain a dedicated scaling factor stored as an 8-bit CUDA_R_8F_UE4M3 value for each 16-element block in the innermost dimension of the corresponding data tensor */ CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3 = 1, /** Same as above, except that scaling factor tensor elements have type CUDA_R_8F_UE8M0 and the block size is 32 elements*/ CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0 = 2, CUBLASLT_MATMUL_MATRIX_SCALE_END } cublasLtMatmulMatrixScale_t; /** Pointer mode to use for alpha/beta */ typedef enum { /** matches CUBLAS_POINTER_MODE_HOST, pointer targets a single value host memory */ CUBLASLT_POINTER_MODE_HOST = CUBLAS_POINTER_MODE_HOST, /** matches CUBLAS_POINTER_MODE_DEVICE, pointer targets a single value device memory */ CUBLASLT_POINTER_MODE_DEVICE = CUBLAS_POINTER_MODE_DEVICE, /** pointer targets an array in device memory */ CUBLASLT_POINTER_MODE_DEVICE_VECTOR = 2, /** alpha pointer targets an array in device memory, beta is zero. Note: CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is not supported, must be 0. */ CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO = 3, /** alpha pointer targets an array in device memory, beta is a single value in host memory.
*/ CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST = 4, } cublasLtPointerMode_t; /** Mask to define pointer mode capability */ typedef enum { /** see CUBLASLT_POINTER_MODE_HOST */ CUBLASLT_POINTER_MODE_MASK_HOST = 1, /** see CUBLASLT_POINTER_MODE_DEVICE */ CUBLASLT_POINTER_MODE_MASK_DEVICE = 2, /** see CUBLASLT_POINTER_MODE_DEVICE_VECTOR */ CUBLASLT_POINTER_MODE_MASK_DEVICE_VECTOR = 4, /** see CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO */ CUBLASLT_POINTER_MODE_MASK_ALPHA_DEVICE_VECTOR_BETA_ZERO = 8, /** see CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST */ CUBLASLT_POINTER_MODE_MASK_ALPHA_DEVICE_VECTOR_BETA_HOST = 16, } cublasLtPointerModeMask_t; /** Implementation details that may affect numerical behavior of algorithms. */ #define CUBLASLT_NUMERICAL_IMPL_FLAGS_FMA (0x01ull << 0) #define CUBLASLT_NUMERICAL_IMPL_FLAGS_HMMA (0x02ull << 0) #define CUBLASLT_NUMERICAL_IMPL_FLAGS_IMMA (0x04ull << 0) #define CUBLASLT_NUMERICAL_IMPL_FLAGS_DMMA (0x08ull << 0) #define CUBLASLT_NUMERICAL_IMPL_FLAGS_TENSOR_OP_MASK (0xfeull << 0) #define CUBLASLT_NUMERICAL_IMPL_FLAGS_OP_TYPE_MASK (0xffull << 0) #define CUBLASLT_NUMERICAL_IMPL_FLAGS_ACCUMULATOR_16F (0x01ull << 8) #define CUBLASLT_NUMERICAL_IMPL_FLAGS_ACCUMULATOR_32F (0x02ull << 8) #define CUBLASLT_NUMERICAL_IMPL_FLAGS_ACCUMULATOR_64F (0x04ull << 8) #define CUBLASLT_NUMERICAL_IMPL_FLAGS_ACCUMULATOR_32I (0x08ull << 8) #define CUBLASLT_NUMERICAL_IMPL_FLAGS_ACCUMULATOR_TYPE_MASK (0xffull << 8) #define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_16F (0x01ull << 16) #define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_16BF (0x02ull << 16) #define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_TF32 (0x04ull << 16) #define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_32F (0x08ull << 16) #define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_64F (0x10ull << 16) #define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_8I (0x20ull << 16) #define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_8F_E4M3 (0x40ull << 16) #define CUBLASLT_NUMERICAL_IMPL_FLAGS_INPUT_8F_E5M2 (0x80ull << 16) #define 
CUBLASLT_NUMERICAL_IMPL_FLAGS_OP_INPUT_TYPE_MASK (0xffull << 16) #define CUBLASLT_NUMERICAL_IMPL_FLAGS_GAUSSIAN (0x01ull << 32) typedef uint64_t cublasLtNumericalImplFlags_t; /** Execute matrix multiplication (D = alpha * op(A) * op(B) + beta * C). * * \retval CUBLAS_STATUS_NOT_INITIALIZED if cuBLASLt handle has not been initialized * \retval CUBLAS_STATUS_INVALID_VALUE if parameters are in conflict or in an impossible configuration; e.g. * when workspaceSizeInBytes is less than workspace required by configured * algo * \retval CUBLAS_STATUS_NOT_SUPPORTED if current implementation on selected device doesn't support configured * operation * \retval CUBLAS_STATUS_ARCH_MISMATCH if configured operation cannot be run using selected device * \retval CUBLAS_STATUS_EXECUTION_FAILED if cuda reported execution error from the device * \retval CUBLAS_STATUS_SUCCESS if the operation completed successfully */ cublasStatus_t CUBLASWINAPI cublasLtMatmul(cublasLtHandle_t lightHandle, cublasLtMatmulDesc_t computeDesc, const void* alpha, /* host or device pointer */ const void* A, cublasLtMatrixLayout_t Adesc, const void* B, cublasLtMatrixLayout_t Bdesc, const void* beta, /* host or device pointer */ const void* C, cublasLtMatrixLayout_t Cdesc, void* D, cublasLtMatrixLayout_t Ddesc, const cublasLtMatmulAlgo_t* algo, void* workspace, size_t workspaceSizeInBytes, cudaStream_t stream); /** Matrix layout conversion helper (C = alpha * op(A) + beta * op(B)) * * Can be used to change memory order of data or to scale and shift the values. * * \retval CUBLAS_STATUS_NOT_INITIALIZED if cuBLASLt handle has not been initialized * \retval CUBLAS_STATUS_INVALID_VALUE if parameters are in conflict or in an impossible configuration; e.g. 
* when A is not NULL, but Adesc is NULL
 * \retval CUBLAS_STATUS_NOT_SUPPORTED     if current implementation on selected device doesn't support configured
 *                                         operation
 * \retval CUBLAS_STATUS_ARCH_MISMATCH     if configured operation cannot be run using selected device
 * \retval CUBLAS_STATUS_EXECUTION_FAILED  if cuda reported execution error from the device
 * \retval CUBLAS_STATUS_SUCCESS           if the operation completed successfully
 */
cublasStatus_t CUBLASWINAPI cublasLtMatrixTransform(cublasLtHandle_t lightHandle,
                                                    cublasLtMatrixTransformDesc_t transformDesc,
                                                    const void* alpha, /* host or device pointer */
                                                    const void* A,
                                                    cublasLtMatrixLayout_t Adesc,
                                                    const void* beta, /* host or device pointer */
                                                    const void* B,
                                                    cublasLtMatrixLayout_t Bdesc,
                                                    void* C,
                                                    cublasLtMatrixLayout_t Cdesc,
                                                    cudaStream_t stream);

/* ---------------------------------------------------------------------------------------*/
/* Helper functions for cublasLtMatrixLayout_t */
/* ---------------------------------------------------------------------------------------*/

/** Enum for data ordering */
typedef enum {
  /** Column-major
   *
   * Leading dimension is the stride (in elements) to the beginning of next column in memory.
   */
  CUBLASLT_ORDER_COL = 0,
  /** Row major
   *
   * Leading dimension is the stride (in elements) to the beginning of next row in memory.
   */
  CUBLASLT_ORDER_ROW = 1,
  /** Column-major ordered tiles of 32 columns.
   *
   * Leading dimension is the stride (in elements) to the beginning of next group of 32-columns. E.g. if matrix has 33
   * columns and 2 rows, ld must be at least (32) * 2 = 64.
   */
  CUBLASLT_ORDER_COL32 = 2,
  /** Column-major ordered tiles of composite tiles with total 32 columns and 8 rows, tile composed of interleaved
   * inner tiles of 4 columns within 4 even or odd rows in an alternating pattern.
   *
   * Leading dimension is the stride (in elements) to the beginning of the first 32 column x 8 row tile for the next
   * 32-wide group of columns. E.g. if matrix has 33 columns and 1 row, ld must be at least (32 * 8) * 1 = 256.
   */
  CUBLASLT_ORDER_COL4_4R2_8C = 3,
  /** Column-major ordered tiles of composite tiles with total 32 columns and 32 rows.
   * Element offset within the tile is calculated as (((row%8)/2*4+row/8)*2+row%2)*32+col.
   *
   * Leading dimension is the stride (in elements) to the beginning of the first 32 column x 32 row tile for the next
   * 32-wide group of columns. E.g. if matrix has 33 columns and 1 row, ld must be at least (32*32)*1 = 1024.
   */
  CUBLASLT_ORDER_COL32_2R_4R4 = 4,
} cublasLtOrder_t;

/** Attributes of memory layout */
typedef enum {
  /** Data type, see cudaDataType.
   *
   * uint32_t
   */
  CUBLASLT_MATRIX_LAYOUT_TYPE = 0,
  /** Memory order of the data, see cublasLtOrder_t.
   *
   * int32_t, default: CUBLASLT_ORDER_COL
   */
  CUBLASLT_MATRIX_LAYOUT_ORDER = 1,
  /** Number of rows.
   *
   * Usually only values that can be expressed as int32_t are supported.
   *
   * uint64_t
   */
  CUBLASLT_MATRIX_LAYOUT_ROWS = 2,
  /** Number of columns.
   *
   * Usually only values that can be expressed as int32_t are supported.
   *
   * uint64_t
   */
  CUBLASLT_MATRIX_LAYOUT_COLS = 3,
  /** Matrix leading dimension.
   *
   * For CUBLASLT_ORDER_COL this is stride (in elements) of matrix column, for more details and documentation for
   * other memory orders see documentation for cublasLtOrder_t values.
   *
   * Currently only non-negative values are supported, must be large enough so that matrix memory locations are not
   * overlapping (e.g. greater or equal to CUBLASLT_MATRIX_LAYOUT_ROWS in case of CUBLASLT_ORDER_COL).
   *
   * int64_t;
   */
  CUBLASLT_MATRIX_LAYOUT_LD = 4,
  /** Number of matmul operations to perform in the batch.
   *
   * See also CUBLASLT_ALGO_CAP_STRIDED_BATCH_SUPPORT
   *
   * int32_t, default: 1
   */
  CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT = 5,
  /** Stride (in elements) to the next matrix for strided batch operation.
   *
   * When matrix type is planar-complex (CUBLASLT_MATRIX_LAYOUT_PLANE_OFFSET != 0), batch stride
   * is interpreted by cublasLtMatmul() in number of real valued sub-elements. E.g. for data of type CUDA_C_16F,
   * offset of 1024B is encoded as a stride of value 512 (since each element of the real and imaginary matrices
   * is a 2B (16bit) floating point type).
   *
   * NOTE: A bug in cublasLtMatrixTransform() causes it to interpret the batch stride for a planar-complex matrix
   * as if it was specified in number of complex elements. Therefore an offset of 1024B must be encoded as stride
   * value 256 when calling cublasLtMatrixTransform() (each complex element is 4B with real and imaginary values 2B
   * each). This behavior is expected to be corrected in the next major cuBLAS version.
   *
   * int64_t, default: 0
   */
  CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET = 6,
  /** Stride (in bytes) to the imaginary plane for planar complex layout.
   *
   * int64_t, default: 0 - 0 means that layout is regular (real and imaginary parts of complex numbers are interleaved
   * in memory in each element)
   */
  CUBLASLT_MATRIX_LAYOUT_PLANE_OFFSET = 7,
} cublasLtMatrixLayoutAttribute_t;

/** Internal. Do not use directly. */
cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutInit_internal(  //
    cublasLtMatrixLayout_t matLayout,
    size_t size,
    cudaDataType type,
    uint64_t rows,
    uint64_t cols,
    int64_t ld);

/** Initialize matrix layout descriptor in pre-allocated space.
 *
 * Thin wrapper: forwards to cublasLtMatrixLayoutInit_internal() with the caller-visible
 * descriptor size so the library can verify the pre-allocated space is large enough.
 *
 * \retval CUBLAS_STATUS_ALLOC_FAILED  if size of the pre-allocated space is insufficient
 * \retval CUBLAS_STATUS_SUCCESS      if descriptor was created successfully
 */
static inline cublasStatus_t cublasLtMatrixLayoutInit(
    cublasLtMatrixLayout_t matLayout, cudaDataType type, uint64_t rows, uint64_t cols, int64_t ld) {
  return cublasLtMatrixLayoutInit_internal(matLayout, sizeof(*matLayout), type, rows, cols, ld);
}

/** Create new matrix layout descriptor.
 *
 * \retval CUBLAS_STATUS_ALLOC_FAILED  if memory could not be allocated
 * \retval CUBLAS_STATUS_SUCCESS      if descriptor was created successfully
 */
cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutCreate(  //
    cublasLtMatrixLayout_t* matLayout,
    cudaDataType type,
    uint64_t rows,
    uint64_t cols,
    int64_t ld);

/** Destroy matrix layout descriptor.
 *
 * \retval CUBLAS_STATUS_SUCCESS  if operation was successful
 */
cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutDestroy(cublasLtMatrixLayout_t matLayout);

/** Set matrix layout descriptor attribute.
 *
 * \param[in]  matLayout    The descriptor
 * \param[in]  attr         The attribute
 * \param[in]  buf          memory address containing the new value
 * \param[in]  sizeInBytes  size of buf buffer for verification (in bytes)
 *
 * \retval CUBLAS_STATUS_INVALID_VALUE  if buf is NULL or sizeInBytes doesn't match size of internal storage for
 *                                      selected attribute
 * \retval CUBLAS_STATUS_SUCCESS       if attribute was set successfully
 */
cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutSetAttribute(  //
    cublasLtMatrixLayout_t matLayout,
    cublasLtMatrixLayoutAttribute_t attr,
    const void* buf,
    size_t sizeInBytes);

/** Get matrix layout descriptor attribute.
 *
 * \param[in]  matLayout    The descriptor
 * \param[in]  attr         The attribute
 * \param[out] buf          memory address containing the new value
 * \param[in]  sizeInBytes  size of buf buffer for verification (in bytes)
 * \param[out] sizeWritten  only valid when return value is CUBLAS_STATUS_SUCCESS.
If sizeInBytes is non-zero: number of
 *                          bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents
 *
 * \retval CUBLAS_STATUS_INVALID_VALUE  if sizeInBytes is 0 and sizeWritten is NULL, or if sizeInBytes is non-zero
 *                                      and buf is NULL or sizeInBytes doesn't match size of internal storage for
 *                                      selected attribute
 * \retval CUBLAS_STATUS_SUCCESS       if attribute's value was successfully written to user memory
 */
cublasStatus_t CUBLASWINAPI cublasLtMatrixLayoutGetAttribute(  //
    cublasLtMatrixLayout_t matLayout,
    cublasLtMatrixLayoutAttribute_t attr,
    void* buf,
    size_t sizeInBytes,
    size_t* sizeWritten);

/* ---------------------------------------------------------------------------------------*/
/* Helper functions for cublasLtMatmulDesc_t */
/* ---------------------------------------------------------------------------------------*/

/** Matmul descriptor attributes to define details of the operation.
 *
 * NOTE(review): values 9 and 16 are unused — presumably retired attributes; confirm against the
 * shipping cublasLt.h before reusing them. */
typedef enum {
  /** Compute type, see cudaDataType. Defines data type used for multiply and accumulate operations and the
   * accumulator during matrix multiplication.
   *
   * int32_t
   */
  CUBLASLT_MATMUL_DESC_COMPUTE_TYPE = 0,
  /** Scale type, see cudaDataType. Defines data type of alpha and beta. Accumulator and value from matrix C are
   * typically converted to scale type before final scaling. Value is then converted from scale type to type of matrix
   * D before being stored in memory.
   *
   * int32_t, default: same as CUBLASLT_MATMUL_DESC_COMPUTE_TYPE
   */
  CUBLASLT_MATMUL_DESC_SCALE_TYPE = 1,
  /** Pointer mode of alpha and beta, see cublasLtPointerMode_t. When CUBLASLT_POINTER_MODE_DEVICE_VECTOR is in use,
   * alpha/beta vector lengths must match number of output matrix rows.
   *
   * int32_t, default: CUBLASLT_POINTER_MODE_HOST
   */
  CUBLASLT_MATMUL_DESC_POINTER_MODE = 2,
  /** Transform of matrix A, see cublasOperation_t.
   *
   * int32_t, default: CUBLAS_OP_N
   */
  CUBLASLT_MATMUL_DESC_TRANSA = 3,
  /** Transform of matrix B, see cublasOperation_t.
   *
   * int32_t, default: CUBLAS_OP_N
   */
  CUBLASLT_MATMUL_DESC_TRANSB = 4,
  /** Transform of matrix C, see cublasOperation_t.
   *
   * Currently only CUBLAS_OP_N is supported.
   *
   * int32_t, default: CUBLAS_OP_N
   */
  CUBLASLT_MATMUL_DESC_TRANSC = 5,
  /** Matrix fill mode, see cublasFillMode_t.
   *
   * int32_t, default: CUBLAS_FILL_MODE_FULL
   */
  CUBLASLT_MATMUL_DESC_FILL_MODE = 6,
  /** Epilogue function, see cublasLtEpilogue_t.
   *
   * uint32_t, default: CUBLASLT_EPILOGUE_DEFAULT
   */
  CUBLASLT_MATMUL_DESC_EPILOGUE = 7,
  /** Bias or bias gradient vector pointer in the device memory.
   *
   * Bias case. See CUBLASLT_EPILOGUE_BIAS.
   * For bias data type see CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE.
   *
   * Bias vector length must match matrix D rows count.
   *
   * Bias gradient case. See CUBLASLT_EPILOGUE_DRELU_BGRAD and CUBLASLT_EPILOGUE_DGELU_BGRAD.
   * Bias gradient vector elements are the same type as the output elements
   * (Ctype) with the exception of IMMA kernels (see above).
   *
   * Routines that don't dereference this pointer, like cublasLtMatmulAlgoGetHeuristic()
   * depend on its value to determine expected pointer alignment.
   *
   * Bias case: const void *, default: NULL
   * Bias gradient case: void *, default: NULL
   */
  CUBLASLT_MATMUL_DESC_BIAS_POINTER = 8,
  /** Batch stride for bias or bias gradient vector.
   *
   * Used together with CUBLASLT_MATMUL_DESC_BIAS_POINTER when matrix D's CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT > 1.
   *
   * int64_t, default: 0
   */
  CUBLASLT_MATMUL_DESC_BIAS_BATCH_STRIDE = 10,
  /** Pointer for epilogue auxiliary buffer.
   *
   * - Output vector for ReLu bit-mask in forward pass when CUBLASLT_EPILOGUE_RELU_AUX
   *   or CUBLASLT_EPILOGUE_RELU_AUX_BIAS epilogue is used.
   * - Input vector for ReLu bit-mask in backward pass when
   *   CUBLASLT_EPILOGUE_DRELU_BGRAD epilogue is used.
   *
   * - Output of GELU input matrix in forward pass when
   *   CUBLASLT_EPILOGUE_GELU_AUX_BIAS epilogue is used.
   * - Input of GELU input matrix for backward pass when
   *   CUBLASLT_EPILOGUE_DGELU_BGRAD epilogue is used.
   *
   * For aux data type see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE.
   *
   * Routines that don't dereference this pointer, like cublasLtMatmulAlgoGetHeuristic()
   * depend on its value to determine expected pointer alignment.
   *
   * Requires setting CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD attribute.
   *
   * Forward pass: void *, default: NULL
   * Backward pass: const void *, default: NULL
   */
  CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER = 11,
  /** Leading dimension for epilogue auxiliary buffer.
   *
   * - ReLu bit-mask matrix leading dimension in elements (i.e. bits)
   *   when CUBLASLT_EPILOGUE_RELU_AUX, CUBLASLT_EPILOGUE_RELU_AUX_BIAS or CUBLASLT_EPILOGUE_DRELU_BGRAD epilogue is
   *   used. Must be divisible by 128 and be no less than the number of rows in the output matrix.
   *
   * - GELU input matrix leading dimension in elements
   *   when CUBLASLT_EPILOGUE_GELU_AUX_BIAS or CUBLASLT_EPILOGUE_DGELU_BGRAD epilogue used.
   *   Must be divisible by 8 and be no less than the number of rows in the output matrix.
   *
   * int64_t, default: 0
   */
  CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD = 12,
  /** Batch stride for epilogue auxiliary buffer.
   *
   * - ReLu bit-mask matrix batch stride in elements (i.e. bits)
   *   when CUBLASLT_EPILOGUE_RELU_AUX, CUBLASLT_EPILOGUE_RELU_AUX_BIAS or CUBLASLT_EPILOGUE_DRELU_BGRAD epilogue is
   *   used. Must be divisible by 128.
   *
   * - GELU input matrix batch stride in elements
   *   when CUBLASLT_EPILOGUE_GELU_AUX_BIAS or CUBLASLT_EPILOGUE_DGELU_BGRAD epilogue used.
   *   Must be divisible by 8.
   *
   * int64_t, default: 0
   */
  CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_BATCH_STRIDE = 13,
  /** Batch stride for alpha vector.
   *
   * Used together with CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST when matrix D's
   * CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT > 1. If CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO is set then
   * CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE must be set to 0 as this mode doesn't support a batched alpha
   * vector.
   *
   * int64_t, default: 0
   */
  CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE = 14,
  /** Number of SMs to target for parallel execution. Optimizes heuristics for execution on a different number of SMs
   * when user expects a concurrent stream to be using some of the device resources.
   *
   * int32_t, default: 0 - use the number reported by the device.
   */
  CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET = 15,
  /** Device pointer to the scale factor value that converts data in matrix A to the compute data type range.
   *
   * The scaling factor value must have the same type as the compute type.
   *
   * If not specified, or set to NULL, the scaling factor is assumed to be 1.
   *
   * If set for an unsupported matrix data, scale, and compute type combination, calling cublasLtMatmul()
   * will return CUBLAS_STATUS_INVALID_VALUE.
   *
   * const void *, default: NULL
   */
  CUBLASLT_MATMUL_DESC_A_SCALE_POINTER = 17,
  /** Device pointer to the scale factor value to convert data in matrix B to compute data type range.
   *
   * The scaling factor value must have the same type as the compute type.
   *
   * If not specified, or set to NULL, the scaling factor is assumed to be 1.
   *
   * If set for an unsupported matrix data, scale, and compute type combination, calling cublasLtMatmul()
   * will return CUBLAS_STATUS_INVALID_VALUE.
   *
   * const void *, default: NULL
   */
  CUBLASLT_MATMUL_DESC_B_SCALE_POINTER = 18,
  /** Device pointer to the scale factor value to convert data in matrix C to compute data type range.
   *
   * The scaling factor value must have the same type as the compute type.
   *
   * If not specified, or set to NULL, the scaling factor is assumed to be 1.
   *
   * If set for an unsupported matrix data, scale, and compute type combination, calling cublasLtMatmul()
   * will return CUBLAS_STATUS_INVALID_VALUE.
   *
   * const void *, default: NULL
   */
  CUBLASLT_MATMUL_DESC_C_SCALE_POINTER = 19,
  /** Device pointer to the scale factor value to convert data in matrix D to compute data type range.
   *
   * The scaling factor value must have the same type as the compute type.
   *
   * If not specified, or set to NULL, the scaling factor is assumed to be 1.
   *
   * If set for an unsupported matrix data, scale, and compute type combination, calling cublasLtMatmul()
   * will return CUBLAS_STATUS_INVALID_VALUE.
   *
   * const void *, default: NULL
   */
  CUBLASLT_MATMUL_DESC_D_SCALE_POINTER = 20,
  /** Device pointer to the memory location that on completion will be set to the maximum of absolute values in the
   * output matrix.
   *
   * The computed value has the same type as the compute type.
   *
   * If not specified or set to NULL, the maximum absolute value is not computed. If set for an unsupported matrix
   * data, scale, and compute type combination, calling cublasLtMatmul() will return CUBLAS_STATUS_INVALID_VALUE.
   *
   * void *, default: NULL
   */
  CUBLASLT_MATMUL_DESC_AMAX_D_POINTER = 21,
  /** Type of the data to be stored to the memory pointed to by CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
   *
   * If unset, the data type defaults to the type of elements of the output matrix with some exceptions, see details
   * below.
   *
   * ReLu uses a bit-mask.
   *
   * GELU input matrix elements type is the same as the type of elements of
   * the output matrix with some exceptions, see details below.
   *
   * For fp8 kernels with output type CUDA_R_8F_E4M3 the aux data type can be CUDA_R_8F_E4M3 or CUDA_R_16F with some
   * restrictions. See https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmulDescAttributes_t for more details.
   *
   * If set for an unsupported matrix data, scale, and compute type combination, calling cublasLtMatmul()
   * will return CUBLAS_STATUS_INVALID_VALUE.
   *
   * int32_t based on cudaDataType, default: -1
   */
  CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_DATA_TYPE = 22,
  /** Device pointer to the scaling factor value to convert results from compute type data range to storage
   * data range in the auxiliary matrix that is set via CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
   *
   * The scaling factor value must have the same type as the compute type.
   *
   * If not specified, or set to NULL, the scaling factor is assumed to be 1. If set for an unsupported matrix data,
   * scale, and compute type combination, calling cublasLtMatmul() will return CUBLAS_STATUS_INVALID_VALUE.
   *
   * void *, default: NULL
   */
  CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_SCALE_POINTER = 23,
  /** Device pointer to the memory location that on completion will be set to the maximum of absolute values in the
   * buffer that is set via CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
   *
   * The computed value has the same type as the compute type.
   *
   * If not specified or set to NULL, the maximum absolute value is not computed. If set for an unsupported matrix
   * data, scale, and compute type combination, calling cublasLtMatmul() will return CUBLAS_STATUS_INVALID_VALUE.
   *
   * void *, default: NULL
   */
  CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_AMAX_POINTER = 24,
  /** Flag for managing fp8 fast accumulation mode.
   * When enabled, problem execution might be faster but at the cost of lower accuracy because intermediate results
   * will not periodically be promoted to a higher precision.
   *
   * int8_t, default: 0 - fast accumulation mode is disabled.
   */
  CUBLASLT_MATMUL_DESC_FAST_ACCUM = 25,
  /** Type of bias or bias gradient vector in the device memory.
   *
   * Bias case: see CUBLASLT_EPILOGUE_BIAS.
   *
   * Bias vector elements are the same type as the elements of output matrix (Dtype) with the following exceptions:
   * - IMMA kernels with computeType=CUDA_R_32I and Ctype=CUDA_R_8I where the bias vector elements
   *   are the same type as alpha, beta (CUBLASLT_MATMUL_DESC_SCALE_TYPE=CUDA_R_32F)
   * - fp8 kernels with an output type of CUDA_R_32F, CUDA_R_8F_E4M3 or CUDA_R_8F_E5M2, See
   *   https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmul for details.
   *
   * int32_t based on cudaDataType, default: -1
   */
  CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE = 26,
  /** EXPERIMENTAL, DEPRECATED: Number of atomic synchronization chunks in the row dimension of the output matrix D.
   *
   * int32_t, default 0 (atomic synchronization disabled)
   */
  CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_ROWS = 27,
  /** EXPERIMENTAL, DEPRECATED: Number of atomic synchronization chunks in the column dimension of the output matrix D.
   *
   * int32_t, default 0 (atomic synchronization disabled)
   */
  CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_NUM_CHUNKS_D_COLS = 28,
  /** EXPERIMENTAL: Pointer to a device array of input atomic counters consumed by a matmul.
   *
   * int32_t *, default: NULL
   */
  CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_IN_COUNTERS_POINTER = 29,
  /** EXPERIMENTAL: Pointer to a device array of output atomic counters produced by a matmul.
   *
   * int32_t *, default: NULL
   */
  CUBLASLT_MATMUL_DESC_ATOMIC_SYNC_OUT_COUNTERS_POINTER = 30,
  /** Scaling mode that defines how the matrix scaling factor for matrix A is interpreted
   *
   * int32_t, default: 0
   */
  CUBLASLT_MATMUL_DESC_A_SCALE_MODE = 31,
  /** Scaling mode that defines how the matrix scaling factor for matrix B is interpreted
   *
   * int32_t, default: 0
   */
  CUBLASLT_MATMUL_DESC_B_SCALE_MODE = 32,
  /** Scaling mode that defines how the matrix scaling factor for matrix C is interpreted
   *
   * int32_t, default: 0
   */
  CUBLASLT_MATMUL_DESC_C_SCALE_MODE = 33,
  /** Scaling mode that defines how the matrix scaling factor for matrix D is interpreted
   *
   * int32_t, default: 0
   */
  CUBLASLT_MATMUL_DESC_D_SCALE_MODE = 34,
  /** Scaling mode that defines how the matrix scaling factor for the auxiliary matrix is interpreted
   *
   * int32_t, default: 0
   */
  CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_SCALE_MODE = 35,
  /** Device pointer to the scale factors that are used to convert data in matrix D to the compute data type range.
   *
   * The scaling factor value type is defined by the scaling mode (see CUBLASLT_MATMUL_DESC_D_OUT_SCALE_MODE)
   *
   * If set for an unsupported matrix data, scale, scale mode, and compute type combination, calling cublasLtMatmul()
   * will return CUBLAS_STATUS_INVALID_VALUE.
   *
   * void *, default: NULL
   */
  CUBLASLT_MATMUL_DESC_D_OUT_SCALE_POINTER = 36,
  /** Scaling mode that defines how the output matrix scaling factor for matrix D is interpreted
   *
   * int32_t, default: 0
   */
  CUBLASLT_MATMUL_DESC_D_OUT_SCALE_MODE = 37,
} cublasLtMatmulDescAttributes_t;

/** Internal. Do not use directly. */
cublasStatus_t CUBLASWINAPI cublasLtMatmulDescInit_internal(  //
    cublasLtMatmulDesc_t matmulDesc,
    size_t size,
    cublasComputeType_t computeType,
    cudaDataType_t scaleType);

/** Initialize matmul operation descriptor in pre-allocated space.
 *
 * Thin wrapper: forwards to cublasLtMatmulDescInit_internal() with the caller-visible
 * descriptor size so the library can verify the pre-allocated space is large enough.
 *
 * \retval CUBLAS_STATUS_ALLOC_FAILED  if size of the pre-allocated space is insufficient
 * \retval CUBLAS_STATUS_SUCCESS      if descriptor was initialized successfully
 */
static inline cublasStatus_t cublasLtMatmulDescInit(  //
    cublasLtMatmulDesc_t matmulDesc,
    cublasComputeType_t computeType,
    cudaDataType_t scaleType) {
  return cublasLtMatmulDescInit_internal(matmulDesc, sizeof(*matmulDesc), computeType, scaleType);
}

/** Create new matmul operation descriptor.
 *
 * \retval CUBLAS_STATUS_ALLOC_FAILED  if memory could not be allocated
 * \retval CUBLAS_STATUS_SUCCESS      if descriptor was created successfully
 */
cublasStatus_t CUBLASWINAPI cublasLtMatmulDescCreate(cublasLtMatmulDesc_t* matmulDesc,
                                                     cublasComputeType_t computeType,
                                                     cudaDataType_t scaleType);

/** Destroy matmul operation descriptor.
 *
 * \retval CUBLAS_STATUS_SUCCESS  if operation was successful
 */
cublasStatus_t CUBLASWINAPI cublasLtMatmulDescDestroy(cublasLtMatmulDesc_t matmulDesc);

/** Set matmul operation descriptor attribute.
*
 * \param[in]  matmulDesc   The descriptor
 * \param[in]  attr         The attribute
 * \param[in]  buf          memory address containing the new value
 * \param[in]  sizeInBytes  size of buf buffer for verification (in bytes)
 *
 * \retval CUBLAS_STATUS_INVALID_VALUE  if buf is NULL or sizeInBytes doesn't match size of internal storage for
 *                                      selected attribute
 * \retval CUBLAS_STATUS_SUCCESS       if attribute was set successfully
 */
cublasStatus_t CUBLASWINAPI cublasLtMatmulDescSetAttribute(  //
    cublasLtMatmulDesc_t matmulDesc,
    cublasLtMatmulDescAttributes_t attr,
    const void* buf,
    size_t sizeInBytes);

/** Get matmul operation descriptor attribute.
 *
 * \param[in]  matmulDesc   The descriptor
 * \param[in]  attr         The attribute
 * \param[out] buf          memory address containing the new value
 * \param[in]  sizeInBytes  size of buf buffer for verification (in bytes)
 * \param[out] sizeWritten  only valid when return value is CUBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero: number
 *                          of bytes actually written, if sizeInBytes is 0: number of bytes needed to write full
 *                          contents
 *
 * \retval CUBLAS_STATUS_INVALID_VALUE  if sizeInBytes is 0 and sizeWritten is NULL, or if sizeInBytes is non-zero
 *                                      and buf is NULL or sizeInBytes doesn't match size of internal storage for
 *                                      selected attribute
 * \retval CUBLAS_STATUS_SUCCESS       if attribute's value was successfully written to user memory
 */
cublasStatus_t CUBLASWINAPI cublasLtMatmulDescGetAttribute(  //
    cublasLtMatmulDesc_t matmulDesc,
    cublasLtMatmulDescAttributes_t attr,
    void* buf,
    size_t sizeInBytes,
    size_t* sizeWritten);

/* ---------------------------------------------------------------------------------------*/
/* Helper functions for cublasLtMatrixTransformDesc_t */
/* ---------------------------------------------------------------------------------------*/

/** Matrix transform descriptor attributes to define details of the operation. */
typedef enum {
  /** Scale type, see cudaDataType. Inputs are converted to scale type for scaling and summation and results are then
   * converted to output type to store in memory.
   *
   * int32_t
   */
  CUBLASLT_MATRIX_TRANSFORM_DESC_SCALE_TYPE,
  /** Pointer mode of alpha and beta, see cublasLtPointerMode_t.
   *
   * int32_t, default: CUBLASLT_POINTER_MODE_HOST
   */
  CUBLASLT_MATRIX_TRANSFORM_DESC_POINTER_MODE,
  /** Transform of matrix A, see cublasOperation_t.
   *
   * int32_t, default: CUBLAS_OP_N
   */
  CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA,
  /** Transform of matrix B, see cublasOperation_t.
   *
   * int32_t, default: CUBLAS_OP_N
   */
  CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSB,
} cublasLtMatrixTransformDescAttributes_t;

/** Internal. Do not use directly. */
cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescInit_internal(cublasLtMatrixTransformDesc_t transformDesc,
                                                                     size_t size,
                                                                     cudaDataType scaleType);

/** Initialize matrix transform operation descriptor in pre-allocated space.
 *
 * Thin wrapper: forwards to cublasLtMatrixTransformDescInit_internal() with the caller-visible
 * descriptor size so the library can verify the pre-allocated space is large enough.
 *
 * \retval CUBLAS_STATUS_ALLOC_FAILED  if size of the pre-allocated space is insufficient
 * \retval CUBLAS_STATUS_SUCCESS      if descriptor was created successfully
 */
static inline cublasStatus_t cublasLtMatrixTransformDescInit(cublasLtMatrixTransformDesc_t transformDesc,
                                                             cudaDataType scaleType) {
  return cublasLtMatrixTransformDescInit_internal(transformDesc, sizeof(*transformDesc), scaleType);
}

/** Create new matrix transform operation descriptor.
 *
 * \retval CUBLAS_STATUS_ALLOC_FAILED  if memory could not be allocated
 * \retval CUBLAS_STATUS_SUCCESS      if descriptor was created successfully
 */
cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescCreate(cublasLtMatrixTransformDesc_t* transformDesc,
                                                              cudaDataType scaleType);

/** Destroy matrix transform operation descriptor.
 *
 * \retval CUBLAS_STATUS_SUCCESS  if operation was successful
 */
cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescDestroy(cublasLtMatrixTransformDesc_t transformDesc);

/** Set matrix transform operation descriptor attribute.
*
 * \param[in]  transformDesc  The descriptor
 * \param[in]  attr           The attribute
 * \param[in]  buf            memory address containing the new value
 * \param[in]  sizeInBytes    size of buf buffer for verification (in bytes)
 *
 * \retval CUBLAS_STATUS_INVALID_VALUE  if buf is NULL or sizeInBytes doesn't match size of internal storage for
 *                                      selected attribute
 * \retval CUBLAS_STATUS_SUCCESS       if attribute was set successfully
 */
cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescSetAttribute(  //
    cublasLtMatrixTransformDesc_t transformDesc,
    cublasLtMatrixTransformDescAttributes_t attr,
    const void* buf,
    size_t sizeInBytes);

/** Get matrix transform operation descriptor attribute.
 *
 * \param[in]  transformDesc  The descriptor
 * \param[in]  attr           The attribute
 * \param[out] buf            memory address containing the new value
 * \param[in]  sizeInBytes    size of buf buffer for verification (in bytes)
 * \param[out] sizeWritten    only valid when return value is CUBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero:
 *                            number of bytes actually written, if sizeInBytes is 0: number of bytes needed to write
 *                            full contents
 *
 * \retval CUBLAS_STATUS_INVALID_VALUE  if sizeInBytes is 0 and sizeWritten is NULL, or if sizeInBytes is non-zero
 *                                      and buf is NULL or sizeInBytes doesn't match size of internal storage for
 *                                      selected attribute
 * \retval CUBLAS_STATUS_SUCCESS       if attribute's value was successfully written to user memory
 */
cublasStatus_t CUBLASWINAPI cublasLtMatrixTransformDescGetAttribute(  //
    cublasLtMatrixTransformDesc_t transformDesc,
    cublasLtMatrixTransformDescAttributes_t attr,
    void* buf,
    size_t sizeInBytes,
    size_t* sizeWritten);

/** Reduction scheme for portions of the dot-product calculated in parallel (a. k. a. "split - K"). */
typedef enum {
  /** No reduction scheme, dot-product shall be performed in one sequence. */
  CUBLASLT_REDUCTION_SCHEME_NONE = 0,
  /** Reduction is performed "in place" - using the output buffer (and output data type) and counters (in workspace)
   * to guarantee the sequentiality. */
  CUBLASLT_REDUCTION_SCHEME_INPLACE = 1,
  /** Intermediate results are stored in compute type in the workspace and reduced in a separate step. */
  CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE = 2,
  /** Intermediate results are stored in output type in the workspace and reduced in a separate step. */
  CUBLASLT_REDUCTION_SCHEME_OUTPUT_TYPE = 4,
  /** Mask covering all of the reduction scheme bits above. */
  CUBLASLT_REDUCTION_SCHEME_MASK = 0x7,
} cublasLtReductionScheme_t;

/** Postprocessing options for the epilogue.
 * Values are bit-combinable: bit 128 marks modes with an auxiliary output/input (see
 * CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER), bit 16 marks bias-gradient variants. */
typedef enum {
  /** No special postprocessing, just scale and quantize results if necessary. */
  CUBLASLT_EPILOGUE_DEFAULT = 1,
  /** ReLu, apply ReLu point-wise transform to the results (x:=max(x, 0)). */
  CUBLASLT_EPILOGUE_RELU = 2,
  /** ReLu, apply ReLu point-wise transform to the results (x:=max(x, 0)).
   *
   * This epilogue mode produces an extra output, a ReLu bit-mask matrix,
   * see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
   */
  CUBLASLT_EPILOGUE_RELU_AUX = (CUBLASLT_EPILOGUE_RELU | 128),
  /** Bias, apply (broadcasted) Bias from bias vector. Bias vector length must match matrix D rows, it must be packed
   * (stride between vector elements is 1). Bias vector is broadcasted to all columns and added before applying final
   * postprocessing. */
  CUBLASLT_EPILOGUE_BIAS = 4,
  /** ReLu and Bias, apply Bias and then ReLu transform */
  CUBLASLT_EPILOGUE_RELU_BIAS = (CUBLASLT_EPILOGUE_RELU | CUBLASLT_EPILOGUE_BIAS),
  /** ReLu and Bias, apply Bias and then ReLu transform
   *
   * This epilogue mode produces an extra output, a ReLu bit-mask matrix,
   * see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
   */
  CUBLASLT_EPILOGUE_RELU_AUX_BIAS = (CUBLASLT_EPILOGUE_RELU_AUX | CUBLASLT_EPILOGUE_BIAS),
  /* ReLu gradient. Apply ReLu gradient to matmul output. Store ReLu gradient in the output matrix.
   *
   * This epilogue mode requires an extra input,
   * see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
   */
  CUBLASLT_EPILOGUE_DRELU = 8 | 128,
  /* ReLu and Bias gradients. Apply independently ReLu and Bias gradient to
   * matmul output. Store ReLu gradient in the output matrix, and Bias gradient
   * in the auxiliary output (see CUBLASLT_MATMUL_DESC_BIAS_POINTER).
   *
   * This epilogue mode requires an extra input,
   * see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
   */
  CUBLASLT_EPILOGUE_DRELU_BGRAD = CUBLASLT_EPILOGUE_DRELU | 16,
  /** GELU, apply GELU point-wise transform to the results (x:=GELU(x)). */
  CUBLASLT_EPILOGUE_GELU = 32,
  /** GELU, apply GELU point-wise transform to the results (x:=GELU(x)).
   *
   * This epilogue mode outputs GELU input as a separate matrix (useful for training).
   * See CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
   */
  CUBLASLT_EPILOGUE_GELU_AUX = (CUBLASLT_EPILOGUE_GELU | 128),
  /** GELU and Bias, apply Bias and then GELU transform */
  CUBLASLT_EPILOGUE_GELU_BIAS = (CUBLASLT_EPILOGUE_GELU | CUBLASLT_EPILOGUE_BIAS),
  /** GELU and Bias, apply Bias and then GELU transform
   *
   * This epilogue mode outputs GELU input as a separate matrix (useful for training).
   * See CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
   */
  CUBLASLT_EPILOGUE_GELU_AUX_BIAS = (CUBLASLT_EPILOGUE_GELU_AUX | CUBLASLT_EPILOGUE_BIAS),
  /* GELU gradient. Apply GELU gradient to matmul output. Store GELU gradient in the output matrix.
   *
   * This epilogue mode requires an extra input,
   * see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
   */
  CUBLASLT_EPILOGUE_DGELU = 64 | 128,
  /* GELU and Bias gradients. Apply independently GELU and Bias gradient to
   * matmul output. Store GELU gradient in the output matrix, and Bias gradient
   * in the auxiliary output (see CUBLASLT_MATMUL_DESC_BIAS_POINTER).
   *
   * This epilogue mode requires an extra input,
   * see CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER.
   */
  CUBLASLT_EPILOGUE_DGELU_BGRAD = CUBLASLT_EPILOGUE_DGELU | 16,
  /** Bias gradient based on the input matrix A.
   *
   * The bias size corresponds to the number of rows of the matrix D.
   * The reduction happens over the GEMM's "k" dimension.
   *
   * Stores Bias gradient in the auxiliary output
   * (see CUBLASLT_MATMUL_DESC_BIAS_POINTER).
   */
  CUBLASLT_EPILOGUE_BGRADA = 256,
  /** Bias gradient based on the input matrix B.
   *
   * The bias size corresponds to the number of columns of the matrix D.
   * The reduction happens over the GEMM's "k" dimension.
   *
   * Stores Bias gradient in the auxiliary output
   * (see CUBLASLT_MATMUL_DESC_BIAS_POINTER).
   */
  CUBLASLT_EPILOGUE_BGRADB = 512,
} cublasLtEpilogue_t;

/** Matmul heuristic search mode */
typedef enum {
  /** ask heuristics for best algo for given usecase */
  CUBLASLT_SEARCH_BEST_FIT = 0,
  /** only try to find best config for preconfigured algo id */
  CUBLASLT_SEARCH_LIMITED_BY_ALGO_ID = 1,
  /** reserved for future use */
  CUBLASLT_SEARCH_RESERVED_02 = 2,
  /** reserved for future use */
  CUBLASLT_SEARCH_RESERVED_03 = 3,
  /** reserved for future use */
  CUBLASLT_SEARCH_RESERVED_04 = 4,
  /** reserved for future use */
  CUBLASLT_SEARCH_RESERVED_05 = 5,
  /** reserved for future use */
  CUBLASLT_SEARCH_RESERVED_06 = 6,
  /** reserved for future use */
  CUBLASLT_SEARCH_RESERVED_07 = 7,
  /** reserved for future use */
  CUBLASLT_SEARCH_RESERVED_08 = 8,
  /** reserved for future use */
  CUBLASLT_SEARCH_RESERVED_09 = 9,
} cublasLtMatmulSearch_t;

/** Algo search preference to fine tune the heuristic function. */
typedef enum {
  /** Search mode, see cublasLtMatmulSearch_t.
   *
   * uint32_t, default: CUBLASLT_SEARCH_BEST_FIT
   */
  CUBLASLT_MATMUL_PREF_SEARCH_MODE = 0,
  /** Maximum allowed workspace size in bytes.
   *
   * uint64_t, default: 0 - no workspace allowed
   */
  CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES = 1,
  /** Reduction scheme mask, see cublasLtReductionScheme_t. Filters heuristic result to only include algo configs that
   * use one of the required modes.
   *
   * E.g. mask value of 0x03 will allow only INPLACE and COMPUTE_TYPE reduction schemes.
   *
   * uint32_t, default: CUBLASLT_REDUCTION_SCHEME_MASK (allows all reduction schemes)
   */
  CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK = 3,
  /** Minimum buffer alignment for matrix A (in bytes).
* * Selecting a smaller value will exclude algorithms that can not work with matrix A that is not as strictly aligned * as they need. * * uint32_t, default: 256 */ CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES = 5, /** Minimum buffer alignment for matrix B (in bytes). * * Selecting a smaller value will exclude algorithms that can not work with matrix B that is not as strictly aligned * as they need. * * uint32_t, default: 256 */ CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES = 6, /** Minimum buffer alignment for matrix C (in bytes). * * Selecting a smaller value will exclude algorithms that can not work with matrix C that is not as strictly aligned * as they need. * * uint32_t, default: 256 */ CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES = 7, /** Minimum buffer alignment for matrix D (in bytes). * * Selecting a smaller value will exclude algorithms that can not work with matrix D that is not as strictly aligned * as they need. * * uint32_t, default: 256 */ CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES = 8, /** Maximum wave count. * * See cublasLtMatmulHeuristicResult_t::wavesCount. * * Selecting a non-zero value will exclude algorithms that report device utilization higher than specified. * * float, default: 0.0f */ CUBLASLT_MATMUL_PREF_MAX_WAVES_COUNT = 9, /** Numerical implementation details mask, see cublasLtNumericalImplFlags_t. Filters heuristic result to only include * algorithms that use the allowed implementations. * * uint64_t, default: uint64_t(-1) (allow everything) */ CUBLASLT_MATMUL_PREF_IMPL_MASK = 12, } cublasLtMatmulPreferenceAttributes_t; /** Internal. Do not use directly. */ cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceInit_internal(cublasLtMatmulPreference_t pref, size_t size); /** Initialize matmul heuristic search preference descriptor in pre-allocated space. 
* * \retval CUBLAS_STATUS_ALLOC_FAILED if size of the pre-allocated space is insufficient * \retval CUBLAS_STATUS_SUCCESS if desciptor was created successfully */ static inline cublasStatus_t cublasLtMatmulPreferenceInit(cublasLtMatmulPreference_t pref) { return cublasLtMatmulPreferenceInit_internal(pref, sizeof(*pref)); } /** Create new matmul heuristic search preference descriptor. * * \retval CUBLAS_STATUS_ALLOC_FAILED if memory could not be allocated * \retval CUBLAS_STATUS_SUCCESS if desciptor was created successfully */ cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceCreate(cublasLtMatmulPreference_t* pref); /** Destroy matmul heuristic search preference descriptor. * * \retval CUBLAS_STATUS_SUCCESS if operation was successful */ cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceDestroy(cublasLtMatmulPreference_t pref); /** Set matmul heuristic search preference descriptor attribute. * * \param[in] pref The descriptor * \param[in] attr The attribute * \param[in] buf memory address containing the new value * \param[in] sizeInBytes size of buf buffer for verification (in bytes) * * \retval CUBLAS_STATUS_INVALID_VALUE if buf is NULL or sizeInBytes doesn't match size of internal storage for * selected attribute * \retval CUBLAS_STATUS_SUCCESS if attribute was set successfully */ cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceSetAttribute( // cublasLtMatmulPreference_t pref, cublasLtMatmulPreferenceAttributes_t attr, const void* buf, size_t sizeInBytes); /** Get matmul heuristic search preference descriptor attribute. * * \param[in] pref The descriptor * \param[in] attr The attribute * \param[out] buf memory address containing the new value * \param[in] sizeInBytes size of buf buffer for verification (in bytes) * \param[out] sizeWritten only valid when return value is CUBLAS_STATUS_SUCCESS. 
If sizeInBytes is non-zero: number of * bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents * * \retval CUBLAS_STATUS_INVALID_VALUE if sizeInBytes is 0 and sizeWritten is NULL, or if sizeInBytes is non-zero * and buf is NULL or sizeInBytes doesn't match size of internal storage for * selected attribute * \retval CUBLAS_STATUS_SUCCESS if attribute's value was successfully written to user memory */ cublasStatus_t CUBLASWINAPI cublasLtMatmulPreferenceGetAttribute( // cublasLtMatmulPreference_t pref, cublasLtMatmulPreferenceAttributes_t attr, void* buf, size_t sizeInBytes, size_t* sizeWritten); /** Results structure used by cublasLtMatmulAlgoGetHeuristic * * Holds returned configured algo descriptor and its runtime properties. */ typedef struct { /** Matmul algorithm descriptor. * * Must be initialized with cublasLtMatmulAlgoInit() if preferences' CUBLASLT_MATMUL_PERF_SEARCH_MODE is set to * CUBLASLT_SEARCH_LIMITED_BY_ALGO_ID */ cublasLtMatmulAlgo_t algo; /** Actual size of workspace memory required. */ size_t workspaceSize; /** Result status, other fields are only valid if after call to cublasLtMatmulAlgoGetHeuristic() this member is set to * CUBLAS_STATUS_SUCCESS. */ cublasStatus_t state; /** Waves count - a device utilization metric. * * wavesCount value of 1.0f suggests that when kernel is launched it will fully occupy the GPU. */ float wavesCount; int reserved[4]; } cublasLtMatmulHeuristicResult_t; /** Query cublasLt heuristic for algorithm appropriate for given use case. * * \param[in] lightHandle Pointer to the allocated cuBLASLt handle for the cuBLASLt * context. See cublasLtHandle_t. * \param[in] operationDesc Handle to the matrix multiplication descriptor. * \param[in] Adesc Handle to the layout descriptors for matrix A. * \param[in] Bdesc Handle to the layout descriptors for matrix B. * \param[in] Cdesc Handle to the layout descriptors for matrix C. * \param[in] Ddesc Handle to the layout descriptors for matrix D. 
* \param[in] preference Pointer to the structure holding the heuristic search * preferences descriptor. See cublasLtMatrixLayout_t. * \param[in] requestedAlgoCount Size of heuristicResultsArray (in elements) and requested * maximum number of algorithms to return. * \param[in, out] heuristicResultsArray Output algorithms and associated runtime characteristics, * ordered in increasing estimated compute time. * \param[out] returnAlgoCount The number of heuristicResultsArray elements written. * * \retval CUBLAS_STATUS_INVALID_VALUE if requestedAlgoCount is less or equal to zero * \retval CUBLAS_STATUS_NOT_SUPPORTED if no heuristic function available for current configuration * \retval CUBLAS_STATUS_SUCCESS if query was successful, inspect * heuristicResultsArray[0 to (returnAlgoCount - 1)].state * for detail status of results */ cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoGetHeuristic(cublasLtHandle_t lightHandle, cublasLtMatmulDesc_t operationDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc, cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc, cublasLtMatmulPreference_t preference, int requestedAlgoCount, cublasLtMatmulHeuristicResult_t heuristicResultsArray[], int* returnAlgoCount); /* ---------------------------------------------------------------------------------------*/ /* Lower level API to be able to implement own Heuristic and Find routines */ /* ---------------------------------------------------------------------------------------*/ /** Routine to get all algo IDs that can potentially run * * \param[in] int requestedAlgoCount requested number of algos (must be less or equal to size of algoIdsA * (in elements)) \param[out] algoIdsA array to write algoIds to \param[out] returnAlgoCount number of algoIds * actually written * * \retval CUBLAS_STATUS_INVALID_VALUE if requestedAlgoCount is less or equal to zero * \retval CUBLAS_STATUS_SUCCESS if query was successful, inspect returnAlgoCount to get actual number of IDs * available */ 
cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoGetIds(cublasLtHandle_t lightHandle, cublasComputeType_t computeType, cudaDataType_t scaleType, cudaDataType_t Atype, cudaDataType_t Btype, cudaDataType_t Ctype, cudaDataType_t Dtype, int requestedAlgoCount, int algoIdsArray[], int* returnAlgoCount); /** Initialize algo structure * * \retval CUBLAS_STATUS_INVALID_VALUE if algo is NULL or algoId is outside of recognized range * \retval CUBLAS_STATUS_NOT_SUPPORTED if algoId is not supported for given combination of data types * \retval CUBLAS_STATUS_SUCCESS if the structure was successfully initialized */ cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoInit(cublasLtHandle_t lightHandle, cublasComputeType_t computeType, cudaDataType_t scaleType, cudaDataType_t Atype, cudaDataType_t Btype, cudaDataType_t Ctype, cudaDataType_t Dtype, int algoId, cublasLtMatmulAlgo_t* algo); /** Check configured algo descriptor for correctness and support on current device. * * Result includes required workspace size and calculated wave count. * * CUBLAS_STATUS_SUCCESS doesn't fully guarantee algo will run (will fail if e.g. buffers are not correctly aligned); * but if cublasLtMatmulAlgoCheck fails, the algo will not run. 
* * \param[in] algo algo configuration to check * \param[out] result result structure to report algo runtime characteristics; algo field is never updated * * \retval CUBLAS_STATUS_INVALID_VALUE if matrix layout descriptors or operation descriptor don't match algo * descriptor * \retval CUBLAS_STATUS_NOT_SUPPORTED if algo configuration or data type combination is not currently supported on * given device * \retval CUBLAS_STATUS_ARCH_MISMATCH if algo configuration cannot be run using the selected device * \retval CUBLAS_STATUS_SUCCESS if check was successful */ cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoCheck( // cublasLtHandle_t lightHandle, cublasLtMatmulDesc_t operationDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc, cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc, const cublasLtMatmulAlgo_t* algo, ///< may point to result->algo cublasLtMatmulHeuristicResult_t* result); /** Capabilities Attributes that can be retrieved from an initialized Algo structure */ typedef enum { /** support for split K, see CUBLASLT_ALGO_CONFIG_SPLITK_NUM * * int32_t, 0 means no support, supported otherwise */ CUBLASLT_ALGO_CAP_SPLITK_SUPPORT = 0, /** reduction scheme mask, see cublasLtReductionScheme_t; shows supported reduction schemes, if reduction scheme is * not masked out it is supported. * * e.g. int isReductionSchemeComputeTypeSupported ? (reductionSchemeMask & CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE) == * CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE ? 
1 : 0; * * uint32_t */ CUBLASLT_ALGO_CAP_REDUCTION_SCHEME_MASK = 1, /** support for cta swizzling, see CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING * * uint32_t, 0 means no support, 1 means supported value of 1, other values are reserved */ CUBLASLT_ALGO_CAP_CTA_SWIZZLING_SUPPORT = 2, /** support strided batch * * int32_t, 0 means no support, supported otherwise */ CUBLASLT_ALGO_CAP_STRIDED_BATCH_SUPPORT = 3, /** support results out of place (D != C in D = alpha.A.B + beta.C) * * int32_t, 0 means no support, supported otherwise */ CUBLASLT_ALGO_CAP_OUT_OF_PLACE_RESULT_SUPPORT = 4, /** syrk/herk support (on top of regular gemm) * * int32_t, 0 means no support, supported otherwise */ CUBLASLT_ALGO_CAP_UPLO_SUPPORT = 5, /** tile ids possible to use, see cublasLtMatmulTile_t; if no tile ids are supported use * CUBLASLT_MATMUL_TILE_UNDEFINED * * use cublasLtMatmulAlgoCapGetAttribute() with sizeInBytes=0 to query actual count * * array of uint32_t */ CUBLASLT_ALGO_CAP_TILE_IDS = 6, /** custom option range is from 0 to CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX (inclusive), see * CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION * * int32_t */ CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX = 7, /** whether algorithm supports custom (not COL or ROW memory order), see cublasLtOrder_t * * int32_t 0 means only COL and ROW memory order is allowed, non-zero means that algo might have different * requirements; */ CUBLASLT_ALGO_CAP_CUSTOM_MEMORY_ORDER = 10, /** bitmask enumerating pointer modes algorithm supports * * uint32_t, see cublasLtPointerModeMask_t */ CUBLASLT_ALGO_CAP_POINTER_MODE_MASK = 11, /** bitmask enumerating kinds of postprocessing algorithm supports in the epilogue * * uint32_t, see cublasLtEpilogue_t */ CUBLASLT_ALGO_CAP_EPILOGUE_MASK = 12, /** stages ids possible to use, see cublasLtMatmulStages_t; if no stages ids are supported use * CUBLASLT_MATMUL_STAGES_UNDEFINED * * use cublasLtMatmulAlgoCapGetAttribute() with sizeInBytes=0 to query actual count * * array of uint32_t */ CUBLASLT_ALGO_CAP_STAGES_IDS 
= 13, /** support for nagative ld for all of the matrices * * int32_t 0 means no support, supported otherwise */ CUBLASLT_ALGO_CAP_LD_NEGATIVE = 14, /** details about algorithm's implementation that affect it's numerical behavior * * uint64_t, see cublasLtNumericalImplFlags_t */ CUBLASLT_ALGO_CAP_NUMERICAL_IMPL_FLAGS = 15, /** minimum alignment required for A matrix in bytes * (required for buffer pointer, leading dimension, and possibly other strides defined for matrix memory order) * * uint32_t */ CUBLASLT_ALGO_CAP_MIN_ALIGNMENT_A_BYTES = 16, /** minimum alignment required for B matrix in bytes * (required for buffer pointer, leading dimension, and possibly other strides defined for matrix memory order) * * uint32_t */ CUBLASLT_ALGO_CAP_MIN_ALIGNMENT_B_BYTES = 17, /** minimum alignment required for C matrix in bytes * (required for buffer pointer, leading dimension, and possibly other strides defined for matrix memory order) * * uint32_t */ CUBLASLT_ALGO_CAP_MIN_ALIGNMENT_C_BYTES = 18, /** minimum alignment required for D matrix in bytes * (required for buffer pointer, leading dimension, and possibly other strides defined for matrix memory order) * * uint32_t */ CUBLASLT_ALGO_CAP_MIN_ALIGNMENT_D_BYTES = 19, /** EXPERIMENTAL: support for synchronization via atomic counters * * int32_t */ CUBLASLT_ALGO_CAP_ATOMIC_SYNC = 20, } cublasLtMatmulAlgoCapAttributes_t; /** Get algo capability attribute. * * E.g. 
to get list of supported Tile IDs: * cublasLtMatmulTile_t tiles[CUBLASLT_MATMUL_TILE_END]; * size_t num_tiles, size_written; * if (cublasLtMatmulAlgoCapGetAttribute(algo, CUBLASLT_ALGO_CAP_TILE_IDS, tiles, sizeof(tiles), size_written) == * CUBLAS_STATUS_SUCCESS) { num_tiles = size_written / sizeof(tiles[0]); * } * * \param[in] algo The algo descriptor * \param[in] attr The attribute * \param[out] buf memory address containing the new value * \param[in] sizeInBytes size of buf buffer for verification (in bytes) * \param[out] sizeWritten only valid when return value is CUBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero: number of * bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents * * \retval CUBLAS_STATUS_INVALID_VALUE if sizeInBytes is 0 and sizeWritten is NULL, or if sizeInBytes is non-zero * and buf is NULL or sizeInBytes doesn't match size of internal storage for * selected attribute * \retval CUBLAS_STATUS_SUCCESS if attribute's value was successfully written to user memory */ cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoCapGetAttribute(const cublasLtMatmulAlgo_t* algo, cublasLtMatmulAlgoCapAttributes_t attr, void* buf, size_t sizeInBytes, size_t* sizeWritten); /** Algo Configuration Attributes that can be set according to the Algo capabilities */ typedef enum { /** algorithm index, see cublasLtMatmulAlgoGetIds() * * readonly, set by cublasLtMatmulAlgoInit() * int32_t */ CUBLASLT_ALGO_CONFIG_ID = 0, /** tile id, see cublasLtMatmulTile_t * * uint32_t, default: CUBLASLT_MATMUL_TILE_UNDEFINED */ CUBLASLT_ALGO_CONFIG_TILE_ID = 1, /** Number of K splits. If the number of K splits is greater than one, SPLITK_NUM parts * of matrix multiplication will be computed in parallel. 
The results will be accumulated * according to CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME * * int32_t, default: 1 */ CUBLASLT_ALGO_CONFIG_SPLITK_NUM = 2, /** reduction scheme, see cublasLtReductionScheme_t * * uint32_t, default: CUBLASLT_REDUCTION_SCHEME_NONE */ CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME = 3, /** cta swizzling, change mapping from CUDA grid coordinates to parts of the matrices * * possible values: 0, 1, other values reserved * * uint32_t, default: 0 */ CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING = 4, /** custom option, each algorithm can support some custom options that don't fit description of the other config * attributes, see CUBLASLT_ALGO_CAP_CUSTOM_OPTION_MAX to get accepted range for any specific case * * uint32_t, default: 0 */ CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION = 5, /** stages id, see cublasLtMatmulStages_t * * uint32_t, default: CUBLASLT_MATMUL_STAGES_UNDEFINED */ CUBLASLT_ALGO_CONFIG_STAGES_ID = 6, /** inner shape id, see cublasLtMatmulInnerShape_t * * uint16_t, default: 0 (CUBLASLT_MATMUL_INNER_SHAPE_UNDEFINED) */ CUBLASLT_ALGO_CONFIG_INNER_SHAPE_ID = 7, /** Thread Block Cluster shape id, see cublasLtClusterShape_t. Defines cluster size to use. * * uint16_t, default: 0 (CUBLASLT_CLUSTER_SHAPE_AUTO) */ CUBLASLT_ALGO_CONFIG_CLUSTER_SHAPE_ID = 8, } cublasLtMatmulAlgoConfigAttributes_t; /** Set algo configuration attribute. * * \param[in] algo The algo descriptor * \param[in] attr The attribute * \param[in] buf memory address containing the new value * \param[in] sizeInBytes size of buf buffer for verification (in bytes) * * \retval CUBLAS_STATUS_INVALID_VALUE if buf is NULL or sizeInBytes doesn't match size of internal storage for * selected attribute * \retval CUBLAS_STATUS_SUCCESS if attribute was set successfully */ cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoConfigSetAttribute(cublasLtMatmulAlgo_t* algo, cublasLtMatmulAlgoConfigAttributes_t attr, const void* buf, size_t sizeInBytes); /** Get algo configuration attribute. 
* * \param[in] algo The algo descriptor * \param[in] attr The attribute * \param[out] buf memory address containing the new value * \param[in] sizeInBytes size of buf buffer for verification (in bytes) * \param[out] sizeWritten only valid when return value is CUBLAS_STATUS_SUCCESS. If sizeInBytes is non-zero: number of * bytes actually written, if sizeInBytes is 0: number of bytes needed to write full contents * * \retval CUBLAS_STATUS_INVALID_VALUE if sizeInBytes is 0 and sizeWritten is NULL, or if sizeInBytes is non-zero * and buf is NULL or sizeInBytes doesn't match size of internal storage for * selected attribute * \retval CUBLAS_STATUS_SUCCESS if attribute's value was successfully written to user memory */ cublasStatus_t CUBLASWINAPI cublasLtMatmulAlgoConfigGetAttribute(const cublasLtMatmulAlgo_t* algo, cublasLtMatmulAlgoConfigAttributes_t attr, void* buf, size_t sizeInBytes, size_t* sizeWritten); /** Experimental: Logger callback type. */ typedef void (*cublasLtLoggerCallback_t)(int logLevel, const char* functionName, const char* message); /** Experimental: Logger callback setter. * * \param[in] callback a user defined callback function to be called by the logger * * \retval CUBLAS_STATUS_SUCCESS if callback was set successfully */ cublasStatus_t CUBLASWINAPI cublasLtLoggerSetCallback(cublasLtLoggerCallback_t callback); /** Experimental: Log file setter. * * \param[in] file an open file with write permissions * * \retval CUBLAS_STATUS_SUCCESS if log file was set successfully */ cublasStatus_t CUBLASWINAPI cublasLtLoggerSetFile(FILE* file); /** Experimental: Open log file. * * \param[in] logFile log file path. if the log file does not exist, it will be created * * \retval CUBLAS_STATUS_SUCCESS if log file was created successfully */ cublasStatus_t CUBLASWINAPI cublasLtLoggerOpenFile(const char* logFile); /** Experimental: Log level setter. * * \param[in] level log level, should be one of the following: * 0. Off * 1. Errors * 2. Performance Trace * 3. 
Performance Hints * 4. Heuristics Trace * 5. API Trace * * \retval CUBLAS_STATUS_INVALID_VALUE if log level is not one of the above levels * * \retval CUBLAS_STATUS_SUCCESS if log level was set successfully */ cublasStatus_t CUBLASWINAPI cublasLtLoggerSetLevel(int level); /** Experimental: Log mask setter. * * \param[in] mask log mask, should be a combination of the following masks: * 0. Off * 1. Errors * 2. Performance Trace * 4. Performance Hints * 8. Heuristics Trace * 16. API Trace * * \retval CUBLAS_STATUS_SUCCESS if log mask was set successfully */ cublasStatus_t CUBLASWINAPI cublasLtLoggerSetMask(int mask); /** Experimental: Disable logging for the entire session. * * \retval CUBLAS_STATUS_SUCCESS if disabled logging */ cublasStatus_t CUBLASWINAPI cublasLtLoggerForceDisable(); #if defined(__cplusplus) } #endif /* __cplusplus */