/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #pragma once #include "./FbgemmBuild.h" // @manual #include "./QuantUtilsAvx2.h" // @manual #include "./QuantUtilsNeon.h" // @manual #include "./Types.h" // @manual #include "./Utils.h" // @manual #include #include #include #include #include /// @defgroup fbgemm-quant-utils-generic Quantization Utilities (Generic) /// namespace fbgemm { FBGEMM_API TensorQuantizationParams ChooseQuantizationParams( float min, float max, std::int32_t qmin, std::int32_t qmax, bool preserve_sparsity = false, bool force_scale_power_of_two = false); FBGEMM_API void ChooseRequantizationMultiplier( float real_multiplier, std::int32_t* quantized_multiplier, int* right_shift, int requantization_multiplier_precision = 32); //////////////////////////////////////////////////////////////////////////////// // Utility functions // Clamp src in T1 to the desired precision and convert it to T2 // TODO: T26263653 fix signed-integer-overflow undefined behavior template NO_SANITIZE("signed-integer-overflow") T2 clamp(T1 src, int precision, bool is_signed = false) { std::int32_t min = is_signed ? -(1LL << (precision - 1)) : 0; std::int32_t max = is_signed ? ((1LL << (precision - 1)) - 1) : (1LL << precision) - 1; // Make sure T1 and T2 can represent the precision assert(min >= std::numeric_limits::lowest()); assert(min >= std::numeric_limits::lowest()); assert(max <= std::numeric_limits::max()); assert(max <= std::numeric_limits::max()); return std::min(std::max(src, min), max); } /// Quantize src using zero_point and scale, clamp to the specified precision, /// and convert it to type T template T Quantize( float src, std::int32_t zero_point, float scale, int result_precision, bool result_is_signed = std::is_signed_v) { // Note: We want to multiply with src with inv_scale instead of // dividing src by scale. The same is done in vector code and // at other places. // // Example: // With scale = 0.00214854861f, zero_point = 0 and src = 0.273939937f // transformed_val is 127.5 for src * inv_scale while // transformed_val is 127.499992 for src / scale. // Eventually 127.5 gets rounded to 128 while 127.499992 gets rounded to 127. float inv_scale = 1.0f / scale; float transformed_val = src * inv_scale; // nearbyint here performs round-to-nearest-ties-to-even with // default rounding mode. // For example, nearbyint(1.4) is 1.0, nearbyint(1.5) is 2.0 // and nearbyint(2.5) is 2.0 // Adding zero_point before or after rounding can make a difference // in exactly halfway cases. if constexpr (LEGACY) { transformed_val = std::nearbyint(zero_point + transformed_val); } else { transformed_val = zero_point + std::nearbyint(transformed_val); } // Please note the use of double. Unlike float, a double can represent // all int32 values exactly. Using a float results in a float value > // INT32_MAX conversion to int32 in clamp function and hence an UBSAN error. return clamp(transformed_val, result_precision, result_is_signed); } template T Quantize(float src, const TensorQuantizationParams& qparams) { return Quantize( src, qparams.zero_point, qparams.scale, qparams.precision); } template FBGEMM_API void Quantize( const float* src, T* dst, std::int64_t len, const TensorQuantizationParams& qparams, int thread_id = 0, int num_threads = 1); /// @ingroup fbgemm-quant-utils-generic /// /// Quantize floating point data in `src` to type `T`. /// /// @tparam T output quantized data type (`int8_t`, `uint8_t`, and `int32_t` are /// supported) /// /// @tparam LAYOUT layout of input tensor in `src`. (`KCX` and `KXC` are /// supported) /// `KCX` corresponds to `KCRS` or `KCTRS` (for weight tensors with time /// dimension) /// `KXC` corresponds to `KRSC` or `KTRSC` (for weight tensors with time /// dimension) /// /// @param K Output channels for weight tensors /// @param C Number of channels /// @param X `R*S` or `T*R*S` /// @param G Groups (if `G == C` the function performs channelwise /// quantization; /// if `1 < G < C` the function performs groupwise /// quantization; if `G == 1` the function performs per tensor /// quantization;) /// @param scales floating point scales. Size should be equal `G` /// @param zero_points zero points (should be reprsentable in type `T`). /// Size should be equal `G` template FBGEMM_API void QuantizeGroupwise( const float* src, int K, int C, int X, int G, const float* scales, const std::int32_t* zero_points, T* dst); template float Dequantize(T src, const TensorQuantizationParams& qparams) { return qparams.scale * (src - qparams.zero_point); } template void Dequantize( const T* src, float* dst, std::int64_t len, const TensorQuantizationParams& qparams, int thread_id = 0, int num_threads = 1) { int64_t i_begin = 0, i_end = 0; fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end); for (int64_t i = i_begin; i < i_end; i++) { dst[i] = Dequantize(src[i], qparams); } } template float FusedQuantizeDequantize( float src, const TensorQuantizationParams& qparams) { T q = Quantize( src, qparams.zero_point, qparams.scale, qparams.precision); return Dequantize(q, qparams); } /// @ingroup fbgemm-quant-utils-generic /// /// Fused integer quantization dequantization kernel to accelerate /// quantization-aware training. Quantize `fp32` values in src to `(u)int8` /// using the provided qparams, and dequantize quantized integer values back /// into `fp32`. template FBGEMM_API void FusedQuantizeDequantize( const float* src, float* dst, std::int64_t len, const TensorQuantizationParams& qparams, int thread_id = 0, int num_threads = 1, float noise_ratio = 0.0f); //////////////////////////////////////////////////////////////////////////////// // Requantization (pure fixed-point) FBGEMM_API std::int64_t SaturatingRoundingMulWithShift(std::int32_t a, std::int32_t b, int right_shift); template T Requantize( std::int32_t src, // int32 input before requantization std::int32_t zero_point, std::int32_t multiplier, int right_shift, int result_precision, bool result_is_signed = false) { std::int64_t quantized_down = zero_point + SaturatingRoundingMulWithShift(src, multiplier, right_shift); return clamp( quantized_down, result_precision, result_is_signed); } template T RequantizeFixedPoint( std::int32_t src, // int32 input before requantization const RequantizationParams& params) { return Requantize( src, params.target_qparams.zero_point, params.multiplier, params.right_shift, params.target_qparams.precision); } template FBGEMM_API void RequantizeFixedPoint( const std::int32_t* src, T* dst, std::int64_t len, const RequantizationParams& params, int thread_id = 0, int num_threads = 1); //////////////////////////////////////////////////////////////////////////////// // Requantization (with floats) template T Requantize( std::int32_t src, // int32 input before requantization std::int32_t zero_point, float multiplier, int result_precision, bool result_is_signed = false) { long quantized_down = zero_point + std::lrintf(src * multiplier); return clamp(quantized_down, result_precision, result_is_signed); } template T Requantize( std::int32_t src, // int32 input before requantization const RequantizationParams& params) { return Requantize( src, params.target_qparams.zero_point, params.real_multiplier, params.target_qparams.precision); } template FBGEMM_API void Requantize( const std::int32_t* src, T* dst, std::int64_t len, const RequantizationParams& params, int thread_id = 0, int num_threads = 1); /** * @ingroup fbgemm-quant-utils-generic * * Convert float (fp32 or fp16) inputs to rowwise quantized outputs. * bitrate specifies the number of bits in quantized output. * Scale and Bias are in fp16. Each row's Scale and Bias are stored in * the row itself (fused) at the end. * * @param bit_rate can be 2, 4, or 8 */ template FBGEMM_API void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf( int bit_rate, const InputType* input, size_t input_rows, int input_columns, std::uint8_t* output); /** * Convert fused rowwise quantized inputs to float (fp32 or fp16). * bitrate specifies the number of bits in quantized input. * Scale and Bias are in fp16. Each row's Scale and Bias are stored in * the row itself (fused) at the end. * * @param bit_rate can be 2, 4, or 8 */ template FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf( int bit_rate, const uint8_t* input, size_t input_rows, int input_columns, OutputType* output, bool scale_bias_last = true); /** * Convert float or half inputs to rowwise quantized (8-bit) outputs. * Scale and Bias are in float. Each row's Scale and Bias are stored in * the row itself (fused) at the end. * * This version intentionally supports only 8-bit because we want to discourage * the usage of float scale and bias with 2 and 4 bit cases as that diminishes * the overall memory savings. */ template FBGEMM_API void FloatOrHalfToFused8BitRowwiseQuantizedSBFloat( const InputType* input, size_t input_rows, int input_columns, std::uint8_t* output, const InputType* rowwise_min_max = nullptr); /** * Convert fused rowwise quantized (8-bit) inputs to float or half outputs. * Scale and Bias are in float. Each row's Scale and Bias are stored in * the row itself (fused) at the end. * * This version intentionally supports only 8-bit because * the corresponding quantize version only supports 8-bit. */ template FBGEMM_API void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf( const uint8_t* input, size_t input_rows, int input_columns, OutputType* output); /** * Same as ToFusedNBitRowwiseQuantizedSBHalf but unoptimized. * This should not be called directly except in testing. */ template FBGEMM_API void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfRef( int bit_rate, const InputType* input, size_t input_rows, int input_columns, std::uint8_t* output); /** * Same as FloatOrHalfToFused8BitRowwiseQuantizedSBFloat but unoptimized. * This should not be called directly except in testing. */ template FBGEMM_API void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatRef( const InputType* input, size_t input_rows, int input_columns, std::uint8_t* output); /** * Same as FusedNBitRowwiseQuantizedSBHalfToFloat but unoptimized. * This should not be called directly except in testing. */ template FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef( int bit_rate, const uint8_t* input, size_t input_rows, int input_columns, OutputType* output, bool scale_bias_last = true); /** * Same as Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf but unoptimized. * This should not be called directly except in testing. */ template FBGEMM_API void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef( const uint8_t* input, size_t input_rows, int input_columns, OutputType* output); } // namespace fbgemm