/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #pragma once #include #include #include #include #include #include #include "./Types.h" // @manual #ifndef __is_identifier #define __is_identifier(x) 1 #endif #define __has_keyword(__x) !(__is_identifier(__x)) // TODO: we're disabling native fp16 on Windows to workaround test failures // due to "undefined symbol __gnu_h2f_ieee" error. We should follup on this // later. #if __has_keyword(__fp16) && !defined(_WIN32) #define HAS_NATIVE_FP16_TYPE using native_fp16_t = __fp16; #elif __has_keyword(_Float16) && !defined(_WIN32) #define HAS_NATIVE_FP16_TYPE using native_fp16_t = _Float16; #else using native_fp16_t = void; #endif namespace fbgemm { namespace detail { template struct FloatFormat { using value_type = T; static constexpr int bits = sizeof(T) * CHAR_BIT; static constexpr int exponent_bits = ExponentBits; static constexpr int mantissa_bits = bits - exponent_bits - 1; static constexpr int sign_bit_pos = bits - 1; static constexpr int exponent_bias = (1 << (exponent_bits - 1)) - 1; static constexpr int unbiased_exponent_min = -exponent_bias + 1; static constexpr int unbiased_exponent_max = HasInfinity ? exponent_bias : (exponent_bias + 1); static constexpr T sign_bit = T{1} << sign_bit_pos; static constexpr T exponent_mask = ((T{1} << exponent_bits) - 1) << mantissa_bits; static constexpr T mantissa_mask = (T{1} << mantissa_bits) - 1; // signaling/quiet encoding is unspecified by IEEE754. This mirrors x86/ARM. static constexpr T quiet_nan_bit = T{1} << (mantissa_bits - 1); static constexpr T nan = exponent_mask | mantissa_mask; static constexpr T overflow_value = HasInfinity ? exponent_mask : nan; static constexpr bool has_infinity = HasInfinity; static constexpr bool has_nan_payload = HasInfinity; }; using IEEE754Single = FloatFormat; using IEEE754Half = FloatFormat; // See https://arxiv.org/abs/1905.12322v3 using BFloat16 = FloatFormat; // See https://doi.org/10.48550/arXiv.2209.05433 using FP8_E5M2 = FloatFormat; // See https://doi.org/10.48550/arXiv.2209.05433 using FP8_E4M3FN = FloatFormat< /*T=*/uint8_t, /*ExponentBits=*/4, /*HasInfinity=*/false>; enum class RoundingMode { ToNearestTiesToEven, ToZero, }; // Generic IEEE754 truncation algorithm. template [[gnu::always_inline]] inline typename Tgt::value_type ieee754_trunc( typename Src::value_type value) { static_assert(Src::exponent_bits >= Tgt::exponent_bits); static_assert(Src::mantissa_bits > Tgt::mantissa_bits); using ST = typename Src::value_type; using TT = typename Tgt::value_type; ST src_exponent = value & Src::exponent_mask; ST src_mantissa = value & Src::mantissa_mask; // Fast-path: If there is no difference in exponent sizes (e.g. fp32 -> bf16) // and we round toward zero, then we can just drop the least significant bits. if constexpr ( Src::exponent_bits == Tgt::exponent_bits && Src::has_infinity && Tgt::has_infinity && RoundingMode == RoundingMode::ToZero) { TT result = value >> (Src::bits - Tgt::bits); // Turn signaling NaN into quiet NaN. This also avoids that the mantissa // is completely zero after truncation (which would be misinterpreted as // INF). if (src_exponent == Src::exponent_mask && src_mantissa != 0) { result |= Tgt::quiet_nan_bit; } return result; } ST tgt_sign = (value & Src::sign_bit) >> (Src::sign_bit_pos - Tgt::sign_bit_pos); constexpr bool denormal_becomes_zero = Tgt::unbiased_exponent_min - Src::unbiased_exponent_min > Src::mantissa_bits - Tgt::mantissa_bits; if constexpr (denormal_becomes_zero) { // Fast-path for zero exponentbits: This means the number was zero or a // denormal number that will turn into zero in the Tgt format. if (src_exponent == 0) { return tgt_sign; // tgt_exponent == 0, tgt_mantissa == 0 } } int unbiased_exponent = (src_exponent >> Src::mantissa_bits) - Src::exponent_bias; if (unbiased_exponent < Tgt::unbiased_exponent_min) { int shift = Tgt::unbiased_exponent_min - unbiased_exponent; if (shift <= Tgt::mantissa_bits + 1) { // Result is denormal. ST src_mantissa_one = src_mantissa; // Add explicit one if the source was not denormal. if (denormal_becomes_zero || src_exponent != 0) { src_mantissa_one |= TT{1} << Src::mantissa_bits; } else { shift--; } TT tgt_mantissa = src_mantissa_one >> (Src::mantissa_bits - Tgt::mantissa_bits + shift); if constexpr (RoundingMode == RoundingMode::ToNearestTiesToEven) { int half_pos = Src::mantissa_bits - Tgt::mantissa_bits + shift - 1; ST half = 1 << half_pos; ST remainder = src_mantissa_one & ((half << 1) - 1); if (remainder > half || (remainder == half && (tgt_mantissa & 1) != 0)) { tgt_mantissa += 1; } } else { static_assert(RoundingMode == RoundingMode::ToZero); } return tgt_sign | tgt_mantissa; // tgt_exponent == 0 } else { // Result is +/- zero return tgt_sign; // tgt_exponent == 0, tgt_mantissa == 0 } } if (unbiased_exponent > Tgt::unbiased_exponent_max) { if (unbiased_exponent == Src::exponent_bias + 1 && src_mantissa != 0) { TT tgt_mantissa; if constexpr (Tgt::has_nan_payload) { // NaN; not a number tgt_mantissa = src_mantissa >> (Src::mantissa_bits - Tgt::mantissa_bits); tgt_mantissa |= Tgt::quiet_nan_bit; } else { tgt_mantissa = Tgt::mantissa_mask; } return tgt_sign | Tgt::exponent_mask | tgt_mantissa; } else { if (RoundingMode == RoundingMode::ToZero && (!Src::has_infinity || src_exponent != Src::exponent_mask)) { // Return largest finite number. return tgt_sign | (Tgt::exponent_mask - Tgt::has_infinity) | Tgt::mantissa_mask; } // Infinity or NaN for formats without infinity. return tgt_sign | Tgt::overflow_value; } } // Normal number. TT tgt_mantissa = src_mantissa >> (Src::mantissa_bits - Tgt::mantissa_bits); TT tgt_exponent = (unbiased_exponent + Tgt::exponent_bias) << Tgt::mantissa_bits; if constexpr (RoundingMode == RoundingMode::ToNearestTiesToEven) { ST half = 1 << (Src::mantissa_bits - Tgt::mantissa_bits - 1); ST remainder = src_mantissa & ((half << 1) - 1); if (remainder > half || (remainder == half && (tgt_mantissa & 1) != 0)) { if (tgt_mantissa < Tgt::mantissa_mask) { tgt_mantissa += 1; } else { // Mantissa overflowed, increment exponent. // Normally we can just add to the exponent and will naturally end up // on infinity on overflow. But we need special treatments for formats // without infinity. if (Tgt::has_infinity || tgt_exponent != Tgt::exponent_mask) { tgt_mantissa = 0; tgt_exponent += TT{1} << Tgt::mantissa_bits; } else { // Return NaN. tgt_mantissa = Tgt::mantissa_mask; } } } } else { static_assert(RoundingMode == RoundingMode::ToZero); } return tgt_sign | tgt_exponent | tgt_mantissa; } } // namespace detail inline float16 cpu_float2half_rn(float f) { uint32_t f_u32 = 0; std::memcpy(&f_u32, &f, sizeof(f_u32)); return detail::ieee754_trunc< /*Src=*/detail::IEEE754Single, /*Tgt=*/detail::IEEE754Half, detail::RoundingMode::ToNearestTiesToEven>(f_u32); } inline float16 cpu_float2half_rz(float f) { uint32_t f_u32 = 0; std::memcpy(&f_u32, &f, sizeof(f_u32)); return detail::ieee754_trunc< /*Src=*/detail::IEEE754Single, /*Tgt=*/detail::IEEE754Half, detail::RoundingMode::ToZero>(f_u32); } // Converts a 16-bit unsigned integer representation of a IEEE754 half-precision // float into an IEEE754 32-bit single-precision float inline float cpu_half2float_ref(const float16 h) { constexpr uint32_t f16_num_exponent_bits = 5; constexpr uint32_t f16_num_mantissa_bits = 10; constexpr uint32_t f16_num_non_sign_bits = f16_num_exponent_bits + f16_num_mantissa_bits; constexpr uint32_t f16_exponent_bias = 15; constexpr uint32_t f16_exponent_mask = 0b1'1111; constexpr uint32_t f16_mantissa_mask = 0b11'1111'1111; constexpr uint32_t f32_num_exponent_bits = 8; constexpr uint32_t f32_num_mantissa_bits = 23; constexpr uint32_t f32_num_non_sign_bits = f32_num_exponent_bits + f32_num_mantissa_bits; constexpr uint32_t f32_exponent_bias = 127; constexpr uint32_t f32_exponent_mask = 0b1111'1111; constexpr uint32_t f32_mantissa_mask = 0x7F'FF'FF; constexpr uint32_t f32_most_significant_bit = 1u << 22; // Get sign and exponent alone by themselves uint32_t sign_bit = (h >> f16_num_non_sign_bits) & 1; uint32_t exponent = (h >> f16_num_mantissa_bits) & f16_exponent_mask; // Shift mantissa so that it fills the most significant bits of a float32 uint32_t mantissa = (h & f16_mantissa_mask) << (f32_num_mantissa_bits - f16_num_mantissa_bits); if (exponent == f16_exponent_mask) { // NaN or Inf if (mantissa) { mantissa = f32_mantissa_mask; sign_bit = 0; } exponent = f32_exponent_mask; } else if (!exponent) { // Denorm or Zero if (mantissa) { uint32_t msb = 0; exponent = f32_exponent_bias - f16_exponent_bias + 1; do { msb = mantissa & f32_most_significant_bit; mantissa <<= 1; // normalize --exponent; } while (!msb); mantissa &= f32_mantissa_mask; // 1.mantissa is implicit } } else { exponent += f32_exponent_bias - f16_exponent_bias; } const uint32_t i = (sign_bit << f32_num_non_sign_bits) | (exponent << f32_num_mantissa_bits) | mantissa; float ret = NAN; std::memcpy(&ret, &i, sizeof(float)); return ret; } // Same as the previous function, but use the built-in fp16 to fp32 // conversion provided by the compiler inline float cpu_half2float(const float16 h) { #if defined(HAS_NATIVE_FP16_TYPE) && not defined(MISSING_GNU_F2H_IEEE) __fp16 h_fp16 = NAN; std::memcpy(&h_fp16, &h, sizeof(__fp16)); return h_fp16; #else return cpu_half2float_ref(h); #endif } inline float16 cpu_float2half(const float f) { #if defined(HAS_NATIVE_FP16_TYPE) && not defined(MISSING_GNU_F2H_IEEE) __fp16 h = f; float16 res = 0; std::memcpy(&res, &h, sizeof(__fp16)); return res; #else return cpu_float2half_rn(f); #endif } inline float cpu_bf162float(bfloat16 src) { float ret = NAN; uint32_t val_fp32 = static_cast(reinterpret_cast(&src)[0]) << 16; std::memcpy(&ret, &val_fp32, sizeof(float)); return ret; } inline bfloat16 cpu_float2bfloat16(float src) { uint32_t temp = 0; std::memcpy(&temp, &src, sizeof(uint32_t)); return (temp + (1u << 15)) >> 16; } } // namespace fbgemm