/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <cstdint>
#include <functional>

#include "fbgemm/FbgemmBuild.h"

namespace fbgemm {

template <
    typename InType,
    typename IndexType,
    typename OffsetType = std::int32_t,
    typename OutType = float>
class EmbeddingSpMDMKernelSignature {
 public:
  /**
   * Behavior follows this pseudocode
   * (when use_offsets == true, lengths[i] == offsets[i + 1] - offsets[i])
   * (when is_weight_positional == true, use weights[j - offsets[i]] instead of
   * weights[j])
   *
   * for i in range(output_size):
   *  out[i * block_size : (i + 1) * block_size] = 0
   *  for j in range(offsets[i], offsets[i + 1]):
   *   for k in range(block_size):
   *    out[i * block_size + k] += input[indices[j] * block_size + k] *
   *     weights ? weights[j] : 1;
   *  if normalize_by_lengths and lengths[i] > 0:
   *   out[i * block_size : (i + 1) * block_size] /= lengths[i]
   *
   * @param data_size the number of rows in the embedding table
   */
  using Type = std::function<bool(
      std::int64_t output_size,
      std::int64_t index_size,
      std::int64_t data_size,
      const InType* input,
      const IndexType* indices,
      const OffsetType* offsets_or_lengths,
      const float* weights, // optional, can be null for non-weighted sum
      OutType* out)>;
};

/**
 * @tparam InType can be float, float16, or uint8_t
 * @tparam IndexType can be int32_t or int64_t
 * @tparam OffsetType can be int32_t or int64_t
 *
 * @param use_offsets If true, the generated code assumes we will pass offsets
 *                    instead of lengths, which conforms to the PyTorch
 *                    EmbeddingBag interface. In this case, the length of the
 *                    offsets array should be output_size + 1 and
 *                    offsets[output_size] should be index_size.
 *                    If false, the generated code assumes we will pass
 *                    lengths, which conforms to the Caffe2 SparseLengthsSum
 *                    interface.
 */
template <
    typename InType,
    typename IndexType,
    typename OffsetType = std::int32_t,
    typename OutType = float,
    bool THREAD_LOCAL = false>
FBGEMM_API typename EmbeddingSpMDMKernelSignature<
    InType,
    IndexType,
    OffsetType,
    OutType>::Type
GenerateEmbeddingSpMDM(
    const std::int64_t block_size,
    bool has_weight,
    bool normalize_by_lengths,
    int prefetch = 16,
    bool is_weight_positional = false,
    bool use_offsets = true,
    bool is_bf16_out = false,
    bool is_bf16_in = false);
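
// Illustrative usage sketch (not part of the API surface): generate a kernel
// once and call it many times. The buffer names and sizes below (emb_table,
// indices, offsets, output, num_rows, batch_size) are hypothetical.
//
//   // fp32 table, 64 floats per row, unweighted, no length normalization.
//   auto kernel = fbgemm::GenerateEmbeddingSpMDM<float, std::int64_t>(
//       /*block_size=*/64,
//       /*has_weight=*/false,
//       /*normalize_by_lengths=*/false);
//   bool success = kernel(
//       /*output_size=*/batch_size,
//       /*index_size=*/offsets[batch_size],
//       /*data_size=*/num_rows,
//       emb_table, // const float*, num_rows x 64
//       indices, // const std::int64_t*, offsets[batch_size] entries
//       offsets, // const std::int32_t*, batch_size + 1 entries
//       /*weights=*/nullptr,
//       output); // float*, batch_size x 64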
/**
 * @param output_stride If -1, output_stride is same as block_size
 * @param input_stride If -1, input_stride is same as block_size
 * @param scale_bias_last if false, scale and bias appear at the beginning
 *        of each row and are in fp16 for table batched embedding (TBE)
 *        in FBGEMM_GPU. If false, it can also take -1 indices (output from
 *        pruned embedding id mapping)
 */
template <
    typename InType,
    typename IndexType,
    typename OffsetType = std::int32_t,
    typename OutType = float,
    bool THREAD_LOCAL = false>
FBGEMM_API typename EmbeddingSpMDMKernelSignature<
    InType,
    IndexType,
    OffsetType,
    OutType>::Type
GenerateEmbeddingSpMDMWithStrides(
    const std::int64_t block_size,
    bool has_weight,
    bool normalize_by_lengths,
    int prefetch = 16,
    bool is_weight_positional = false,
    bool use_offsets = true,
    std::int64_t output_stride = -1,
    std::int64_t input_stride = -1,
    bool scale_bias_last = true,
    bool no_bag = false,
    bool is_bf16_out = false,
    bool is_bf16_in = false);

/**
 * @tparam IndexType can be int32_t or int64_t
 * @tparam OffsetType can be int32_t or int64_t
 * @param bit_rate can be 2 or 4
 */
template <
    typename IndexType,
    typename OffsetType = std::int32_t,
    typename OutType = float>
FBGEMM_API typename EmbeddingSpMDMKernelSignature<
    std::uint8_t,
    IndexType,
    OffsetType,
    OutType>::Type
GenerateEmbeddingSpMDMNBit(
    int bit_rate,
    const std::int64_t block_size,
    bool has_weight,
    bool normalize_by_lengths,
    int prefetch = 16,
    bool is_weight_positional = false,
    bool use_offsets = true);

/**
 * @param output_stride If -1, output_stride is same as block_size
 * @param input_stride in Bytes. If -1, input_stride is same as
 *        block_size / num_elem_per_byte + 2 * sizeof(float16)
 * @param scale_bias_last if false, scale and bias appear at the beginning
 *        of each row and are in fp16 for table batched embedding (TBE)
 *        in FBGEMM_GPU. If false, it can also take -1 indices (output from
 *        pruned embedding id mapping)
 */
template <
    typename IndexType,
    typename OffsetType = std::int32_t,
    typename OutType = float,
    bool THREAD_LOCAL = false>
FBGEMM_API typename EmbeddingSpMDMKernelSignature<
    std::uint8_t,
    IndexType,
    OffsetType,
    OutType>::Type
GenerateEmbeddingSpMDMNBitWithStrides(
    const int input_bit_rate,
    const std::int64_t block_size,
    bool has_weight,
    bool normalize_by_lengths,
    int prefetch = 16,
    bool is_weight_positional = false,
    bool use_offsets = true,
    std::int64_t output_stride = -1,
    std::int64_t input_stride = -1,
    bool scale_bias_last = true,
    const bool is_bf16_out = false,
    const bool no_bag = false,
    int output_bit_rate = -1);
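
// Illustrative usage sketch (not part of the API surface): a 4-bit quantized
// bag. Per the input_stride documentation above, each row stores
// block_size / num_elem_per_byte bytes of packed values followed by an fp16
// scale and an fp16 bias when scale_bias_last is true. The names below
// (packed_table, indices, offsets, output, num_rows, batch_size) are
// hypothetical.
//
//   auto kernel = fbgemm::GenerateEmbeddingSpMDMNBit<std::int64_t>(
//       /*bit_rate=*/4,
//       /*block_size=*/128,
//       /*has_weight=*/false,
//       /*normalize_by_lengths=*/false);
//   bool success = kernel(
//       /*output_size=*/batch_size,
//       /*index_size=*/offsets[batch_size],
//       /*data_size=*/num_rows,
//       packed_table, // const std::uint8_t*, num_rows x (128 / 2 + 4) bytes
//       indices,
//       offsets,
//       /*weights=*/nullptr,
//       output); // float*, batch_size x 128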
/**
 * @param output_stride If -1, output_stride is same as block_size
 * @param input_stride in Bytes. If -1, input_stride is same as
 *        block_size / num_elem_per_byte + 2 * sizeof(float16)
 * @param exponent_bits is the number of exponent bits in the FP8 encode
 *        (normally 4 or 5)
 * @param exponent_bias is subtracted from the exponent to obtain the actual
 *        exponent for the floating-point number
 */
template <
    typename IndexType,
    typename OffsetType = std::int32_t,
    typename OutType = float>
FBGEMM_API typename EmbeddingSpMDMKernelSignature<
    std::uint8_t,
    IndexType,
    OffsetType,
    OutType>::Type
GenerateEmbeddingSpMDMFP8WithStrides(
    const std::int64_t block_size,
    bool normalize_by_lengths,
    bool is_weight_positional = false,
    bool use_offsets = true,
    std::int64_t output_stride = -1,
    std::int64_t input_stride = -1,
    int exponent_bits = 4,
    int exponent_bias = 7,
    bool is_bf16_out = false);

template <
    typename InType,
    typename IndexType,
    typename OffsetType = std::int32_t>
class EmbeddingSpMDMRowWiseSparseKernelSignature {
 public:
  using Type = std::function<bool(
      std::int64_t output_size,
      std::int64_t index_size,
      std::int64_t uncompressed_data_size,
      const InType* input,
      const IndexType* indices,
      const OffsetType* offsets_or_lengths,
      const float* weights, // optional, can be null for non-weighted sum
      float* out,
      const std::int32_t* compressed_indices_table)>;
};

/**
 * @tparam InType can be float, float16, or uint8_t
 * @tparam IndexType can be int32_t or int64_t
 * @tparam OffsetType can be int32_t or int64_t
 */
template <
    typename InType,
    typename IndexType,
    typename OffsetType = std::int32_t>
FBGEMM_API typename EmbeddingSpMDMRowWiseSparseKernelSignature<
    InType,
    IndexType,
    OffsetType>::Type
GenerateEmbeddingSpMDMRowWiseSparse(
    const std::int64_t block_size,
    bool has_weight,
    bool normalize_by_lengths,
    int prefetch = 16,
    bool is_weight_positional = false,
    bool use_offsets = true);

/**
 * @tparam IndexType can be int32_t or int64_t
 * @tparam OffsetType can be int32_t or int64_t
 * @param bit_rate can be 2 or 4
 */
template <typename IndexType, typename OffsetType = std::int32_t>
FBGEMM_API typename EmbeddingSpMDMRowWiseSparseKernelSignature<
    std::uint8_t,
    IndexType,
    OffsetType>::Type
GenerateEmbeddingSpMDMNBitRowWiseSparse(
    int bit_rate,
    const std::int64_t block_size,
    bool has_weight,
    bool normalize_by_lengths,
    int prefetch = 16,
    bool is_weight_positional = false,
    bool use_offsets = true);
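
// Illustrative usage sketch (not part of the API surface): the row-wise
// sparse variant reads rows through compressed_indices_table, so the input
// table only stores rows that survived pruning. The names below
// (compressed_table, mapping, indices, offsets, output, ...) are hypothetical.
//
//   auto kernel =
//       fbgemm::GenerateEmbeddingSpMDMRowWiseSparse<float, std::int64_t>(
//           /*block_size=*/64,
//           /*has_weight=*/false,
//           /*normalize_by_lengths=*/false);
//   bool success = kernel(
//       /*output_size=*/batch_size,
//       /*index_size=*/offsets[batch_size],
//       /*uncompressed_data_size=*/num_uncompressed_rows,
//       compressed_table, // const float*, one row per surviving id
//       indices, // uncompressed row ids
//       offsets,
//       /*weights=*/nullptr,
//       output,
//       mapping); // const std::int32_t*, uncompressed id -> compressed row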
/**
 * @return The number of rows processed. If smaller than num_rows, an error
 *         must have happened at the last row processed.
 */
template <typename IndexType>
class SparseAdaGradSignature {
 public:
  using Type = std::function<int(
      int num_rows, // number of rows reading
      std::uint64_t param_size, // total number of parameters
      float* w, // input/output parameters
      const float* g, // input gradients
      float* h, // input/output momentums
      const IndexType* indices, // indices of each row
      float epsilon,
      float lr,
      float weight_decay,
      const double* counter, // used for frequency-adjusted weight_decay;
                             // nullptr when frequency adjustment is not used
      std::int64_t counter_halflife)>; // frequency adjust happens only after
};

template <typename IndexType>
FBGEMM_API typename SparseAdaGradSignature<IndexType>::Type
GenerateSparseAdaGrad(
    int block_size, // number of parameters per row
    bool rowwise = false,
    int prefetch = 16,
    bool use_weight_decay = false);

// RowWiseSparseAdaGrad fused with SLS gradient
// Weights can be either float or float16
template <
    typename IndexType,
    typename OffsetType = std::int32_t,
    typename DataType = float>
class RowWiseSparseAdaGradFusedSignature {
 public:
  using Type = std::function<bool(
      std::int64_t output_size,
      std::int64_t index_size,
      std::int64_t data_size, // number of rows in w
      DataType* w, // input/output parameters
      const float* g, // input gradients
      float* h, // input/output momentums
      const IndexType* indices, // indices of each row
      const OffsetType* offsets_or_lengths,
      float epsilon,
      float lr)>;
};

/**
 * @param grad_stride If -1, grad_stride is same as block size
 */
template <
    typename IndexType,
    typename OffsetType = std::int32_t,
    typename DataType = float>
FBGEMM_API typename RowWiseSparseAdaGradFusedSignature<
    IndexType,
    OffsetType,
    DataType>::Type
GenerateRowWiseSparseAdaGradFused(
    int block_size, // number of parameters per row
    int prefetch = 16,
    bool use_offsets = true,
    bool use_stochastic_rounding = true,
    int grad_stride = -1);

namespace internal {

// Specialization for block size 1 internally called by GenerateEmbeddingSpMDM
template <
    typename InType,
    typename IndexType,
    typename OffsetType = std::int32_t>
FBGEMM_API bool EmbeddingSpMDMBlockSize1_(
    const std::int64_t output_size,
    const std::int64_t index_size,
    const std::int64_t data_size, // the number of rows in input
    const InType* input,
    const IndexType* indices,
    const OffsetType* offsets_or_lengths,
    const float* weights, // optional, can be null for non-weighted sum
    bool normalize_by_lengths,
    float* out,
    bool is_weight_positional = false,
    bool use_offsets = true,
    bool is_bf16 = false);

#if defined(FBGEMM_FBCODE) || !defined(__aarch64__)
template <typename IndexType>
void compressed_indices_remap_avx512(
    std::int32_t offsets_numel,
    const IndexType* indices,
    const int32_t* compressed_indices_mapping,
    const IndexType* offsets,
    const float* weights, // optional, can be null
    IndexType* out_indices,
    IndexType* out_offsets,
    float* out_weights);
#endif

// Specialization for uint8_t* input on aarch64 called by
// GenerateEmbeddingSpMDM
template <
    typename IndexType,
    typename OffsetType,
    typename OutType,
    bool NoBag,
    bool EnablePrefetching>
FBGEMM_API bool EmbeddingSpMDM8Bit_Sve(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const IndexType* indices,
    const OffsetType* offsets_or_lengths,
    const float* weights, // optional, can be null for non-weighted sum
    const bool normalize_by_lengths,
    OutType* out,
    const bool is_weight_positional,
    const bool use_offsets,
    const int64_t output_stride,
    const int64_t input_stride,
    const bool scale_bias_last,
    const bool is_bf16_out);

} // namespace internal

template <typename IndexType>
FBGEMM_API void compressed_indices_remap(
    std::int32_t offsets_numel,
    const IndexType* indices,
    const int32_t* compressed_indices_mapping,
    const IndexType* offsets,
    const float* weights, // optional, can be null
    IndexType* out_indices,
    IndexType* out_offsets,
    float* out_weights);

} // namespace fbgemm
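
// Illustrative usage sketch (not part of the API surface): a row-wise AdaGrad
// update fused with the SLS gradient, following the
// RowWiseSparseAdaGradFusedSignature::Type alias above. The names
// (w, g, h, indices, offsets, num_rows, batch_size) are hypothetical buffers
// owned by the caller.
//
//   auto update =
//       fbgemm::GenerateRowWiseSparseAdaGradFused<std::int64_t>(
//           /*block_size=*/64);
//   bool success = update(
//       /*output_size=*/batch_size,
//       /*index_size=*/offsets[batch_size],
//       /*data_size=*/num_rows,
//       w, // float*, num_rows x 64 parameters, updated in place
//       g, // const float*, batch_size x 64 gradients
//       h, // float*, per-row momentum (row-wise AdaGrad)
//       indices,
//       offsets,
//       /*epsilon=*/1e-5f,
//       /*lr=*/0.01f);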