Enhancements to support C++23 std::bfloat16_t type #1254

Open
johnplatts opened this issue Mar 23, 2023 · 1 comment

@johnplatts
Contributor

The upcoming C++23 standard adds support for the std::float16_t and std::bfloat16_t types, and the upcoming GCC 13 release will have support for the std::float16_t and std::bfloat16_t types with the -std=c++23 option.

The hwy::bfloat16_t type should also be updated to allow for implicit conversions to/from std::bfloat16_t in C++23 or later mode on platforms that support the std::bfloat16_t type.

Here is how support for conversions between std::bfloat16_t and hwy::bfloat16_t could be implemented:

#if HWY_HAS_INCLUDE(<version>)
#include <version>
#endif

#if HWY_HAS_BUILTIN(__builtin_bit_cast) || HWY_COMPILER_MSVC >= 1926
#define HWY_HAS_BUILTIN_BIT_CAST 1
#define HWY_HAS_CONSTEXPR_BIT_CAST 1
#elif defined(__cpp_lib_bit_cast) && __cpp_lib_bit_cast >= 201806L && \
      HWY_HAS_INCLUDE(<bit>)
#define HWY_HAS_BUILTIN_BIT_CAST 0
#define HWY_HAS_CONSTEXPR_BIT_CAST 1
#include <bit>
#else
#define HWY_HAS_BUILTIN_BIT_CAST 0
#define HWY_HAS_CONSTEXPR_BIT_CAST 0
#endif

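// 202100L (rather than 202302L) also accepts compilers that report a
// provisional C++23 value, such as GCC 13 with -std=c++23.
#if __cplusplus >= 202100L && HWY_HAS_INCLUDE(<stdfloat>)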
#define HWY_HAS_CXX23_STDFLOAT 1
#include <stdfloat>
#else
#define HWY_HAS_CXX23_STDFLOAT 0
#endif

#if HWY_HAS_CXX23_STDFLOAT && defined(__STDCPP_BFLOAT16_T__)
#define HWY_HAS_CXX23_BFLOAT16_T 1
#else
#define HWY_HAS_CXX23_BFLOAT16_T 0
#endif

namespace hwy {
// ... other definitions

template <typename To, typename From>
HWY_API
#if HWY_HAS_CONSTEXPR_BIT_CAST
constexpr
#endif
To ValueBitCast(const From& val) {
#if HWY_HAS_BUILTIN_BIT_CAST
  return __builtin_bit_cast(To, val);
#elif HWY_HAS_CONSTEXPR_BIT_CAST  // std::bit_cast path: <bit> was included
  return std::bit_cast<To>(val);
#else
  To result;
  CopySameSize(&val, &result);
  return result;
#endif
}

struct SpecialFloatFromBitsTag {
  constexpr SpecialFloatFromBitsTag() {}
};

namespace detail {
  HWY_INLINE
#if __cpp_constexpr >= 201304L
  constexpr
#endif
  uint16_t F32BitsToBF16Bits(uint32_t f32_bits) noexcept {
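    uint32_t abs_f32_bits = f32_bits & 0x7FFFFFFFu;
    if(abs_f32_bits < 0x7F800000u) {
      // Finite: round to nearest, ties to even, before dropping the low 16 bits.
      f32_bits += 0x00007FFFu + ((f32_bits >> 16) & 1u);
    } else if(abs_f32_bits > 0x7F800000u) {
      // NaN: set the quiet bit so that truncation cannot turn it into infinity.
      f32_bits |= 0x00400000u;
    }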
    return static_cast<uint16_t>(f32_bits >> 16);
  }
}

struct bfloat16_t {
  uint16_t bits;

  bfloat16_t() = default;
  HWY_INLINE constexpr bfloat16_t(SpecialFloatFromBitsTag /* tag */,
    uint16_t bf16_bits) noexcept : bits{bf16_bits} {}

  explicit
  HWY_INLINE
#if __cpp_constexpr >= 201304L && HWY_HAS_CONSTEXPR_BIT_CAST
  constexpr
#endif
  bfloat16_t(float val) noexcept :
    bits{detail::F32BitsToBF16Bits(ValueBitCast<uint32_t>(val))} {}
  
#if HWY_HAS_CXX23_BFLOAT16_T
  HWY_INLINE constexpr bfloat16_t(std::bfloat16_t bf16_val) noexcept :
    bits{ValueBitCast<uint16_t>(bf16_val)} {}

  HWY_INLINE constexpr operator std::bfloat16_t() const noexcept {
    return ValueBitCast<std::bfloat16_t>(bits);
  }
#endif

  HWY_INLINE
#if HWY_HAS_CONSTEXPR_BIT_CAST
  constexpr
#endif
  operator float() const noexcept {
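    // Widen to uint32_t first: shifting the int-promoted uint16_t left by 16
    // overflows int for patterns with the sign bit set (undefined before C++20).
    return ValueBitCast<float>(static_cast<uint32_t>(bits) << 16);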
  }
};

// ... other definitions

}  // namespace hwy
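
For checking the conversion logic outside of Highway, here is a minimal standalone sketch of the same scalar BF16 conversions. It assumes IEEE-754 binary32 float and uses std::memcpy in place of ValueBitCast/CopySameSize; the helper names (BitsFromF32, F32FromBits, RoundF32ToBF16Bits, BF16BitsToF32) are hypothetical and not part of Highway.

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical helpers: memcpy stands in for ValueBitCast / CopySameSize.
inline uint32_t BitsFromF32(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return u;
}

inline float F32FromBits(uint32_t u) {
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

// Same rounding as detail::F32BitsToBF16Bits: round to nearest, ties to even,
// for finite values; force the quiet bit on NaNs so that dropping the low
// 16 bits cannot turn a NaN into an infinity; leave infinities unchanged.
inline uint16_t RoundF32ToBF16Bits(float f) {
  uint32_t bits = BitsFromF32(f);
  const uint32_t abs_bits = bits & 0x7FFFFFFFu;
  if (abs_bits < 0x7F800000u) {
    bits += 0x00007FFFu + ((bits >> 16) & 1u);
  } else if (abs_bits > 0x7F800000u) {
    bits |= 0x00400000u;
  }
  return static_cast<uint16_t>(bits >> 16);
}

// Widening is exact: the BF16 bits become the upper half of a binary32.
// The uint32_t cast avoids shifting into the sign bit of a promoted int.
inline float BF16BitsToF32(uint16_t bf16_bits) {
  return F32FromBits(static_cast<uint32_t>(bf16_bits) << 16);
}
```

Under these assumptions, RoundF32ToBF16Bits(2.75f) yields 0x4030, an exact halfway input such as the float with bits 0x3F808000 rounds down to the even pattern 0x3F80, and a NaN whose payload lives only in the low 16 bits comes out as the quiet NaN 0x7FC0 rather than being truncated to infinity (0x7F80).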

The updated implementation of hwy::bfloat16_t above will compile in C++11 or later mode, including with older C++11 compilers such as g++ 4.7.1 or clang 3.4.1.

The updated implementation of hwy::bfloat16_t also adds a hwy::bfloat16_t(SpecialFloatFromBitsTag, uint16_t) constructor, so that constructing a hwy::bfloat16_t from its bit representation is syntactically distinct from converting a floating-point value to hwy::bfloat16_t.
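
The difference between the two constructors can be illustrated with a small hypothetical type. sketch::BF16 and sketch::FromBitsTag are illustrative names, not Highway's; the sketch only mirrors the tag-based constructor and the explicit float constructor proposed above, and assumes IEEE-754 binary32 float.

```cpp
#include <cstdint>
#include <cstring>

namespace sketch {  // hypothetical namespace, not Highway's

struct FromBitsTag {
  constexpr FromBitsTag() {}
};

struct BF16 {
  uint16_t bits;

  BF16() = default;
  // Bit construction: the argument is already an encoded BF16 pattern.
  constexpr BF16(FromBitsTag /* tag */, uint16_t bf16_bits) noexcept
      : bits(bf16_bits) {}
  // Value conversion: rounds a float to the nearest BF16 (ties to even).
  explicit BF16(float val) noexcept : bits(RoundToBF16Bits(val)) {}

  float ToFloat() const noexcept {
    const uint32_t u = static_cast<uint32_t>(bits) << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
  }

 private:
  static uint16_t RoundToBF16Bits(float val) noexcept {
    uint32_t u;
    std::memcpy(&u, &val, sizeof(u));
    const uint32_t abs_u = u & 0x7FFFFFFFu;
    if (abs_u < 0x7F800000u) {
      u += 0x00007FFFu + ((u >> 16) & 1u);  // round to nearest even
    } else if (abs_u > 0x7F800000u) {
      u |= 0x00400000u;  // keep NaNs NaN after truncation
    }
    return static_cast<uint16_t>(u >> 16);
  }
};

}  // namespace sketch
```

With this sketch, BF16(FromBitsTag(), 0x4030) holds 2.75, whereas BF16{0x4030} goes through the float constructor, converts the integer 16432, and stores 0x4680 (16384.0).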

The behavior of code such as the function below will change with the updated hwy::bfloat16_t implementation above:

hwy::bfloat16_t SomeFuncThatReturnsABF16() {
  return hwy::bfloat16_t{0x4030};
}

With the current implementation, SomeFuncThatReturnsABF16 returns a hwy::bfloat16_t with the value 2.75, because 0x4030 initializes the bits member directly and 0x4030 is the BF16 bit pattern of 2.75. With the updated implementation, 0x4030 is instead treated as the integer value 16432 and converted to BF16. Since 16432 is not exactly representable in BF16 (adjacent BF16 values between 16384 and 32768 are 128 apart), the result is rounded to the nearest BF16 value, 16384.0.
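
A hand-worked decoding of both readings of 0x4030, assuming the standard BF16 layout (1 sign bit, 8 exponent bits with bias 127, 7 mantissa bits):

```latex
\begin{aligned}
\texttt{0x4030} &= \underbrace{0}_{\text{sign}}\ \underbrace{1000\,0000_2}_{\text{exp}=128}\ \underbrace{011\,0000_2}_{\text{mantissa}}
  \;\Rightarrow\; (1 + 0.011_2)\cdot 2^{128-127} = 1.375 \cdot 2 = 2.75 \\
\texttt{0x4030} &= 16432 = 2^{14} + 48, \qquad
  \text{BF16 spacing on } [2^{14}, 2^{15}) = 2^{14-7} = 128 \\
16432 &\in [16384,\ 16512], \quad 16432 - 16384 = 48 < 64
  \;\Rightarrow\; \operatorname{rn}_{\mathrm{BF16}}(16432) = 16384 = \texttt{0x4680}
\end{aligned}
```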

The updated hwy::bfloat16_t implementation above will allow code such as the following in C++23 mode with GCC 13 or later (with updates to the Set function):

namespace example {
namespace HWY_NAMESPACE {

using namespace hwy;
using namespace hwy::HWY_NAMESPACE;

template<class V>
static HWY_INLINE auto InvertNonSignBitsIfNegative(V v) {
  const DFromV<decltype(v)> d;
  const RebindToSigned<decltype(d)> di;
  const RebindToUnsigned<decltype(d)> du;

  const auto vi = BitCast(di, v);
  const auto invert_mask = BitCast(di, ShiftRight<1>(
    BitCast(du, BroadcastSignBit(vi))));
  return Xor(vi, invert_mask);
}

template<class V>
static HWY_INLINE V FloatMinUsingIntCompare(V a, V b) {
  const DFromV<decltype(a)> d;

  return BitCast(d, InvertNonSignBitsIfNegative(Min(
    InvertNonSignBitsIfNegative(a),
    InvertNonSignBitsIfNegative(b))));
}

template<class V>
static HWY_INLINE V FloatMaxUsingIntCompare(V a, V b) {
  const DFromV<decltype(a)> d;

  return BitCast(d, InvertNonSignBitsIfNegative(Max(
    InvertNonSignBitsIfNegative(a),
    InvertNonSignBitsIfNegative(b))));
}

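void SomeFuncThatProcessesBFloat16(hwy::bfloat16_t* result_ptr,
  const hwy::bfloat16_t* src_ptr) {
  const ScalableTag<hwy::bfloat16_t> d;
  const auto v = Load(d, src_ptr);
  // Clamp each lane to [-5.375, 2.75]. The constants are Set on the
  // bfloat16_t tag d so that both arguments have the same vector type V.
  const auto clamped = FloatMaxUsingIntCompare(
    FloatMinUsingIntCompare(v, Set(d, 2.75bf16)),
    Set(d, -5.375bf16));
  Store(clamped, d, result_ptr);
}

}  // namespace HWY_NAMESPACE
}  // namespace example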
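
The integer-compare trick used by InvertNonSignBitsIfNegative can be checked one lane at a time. The scalar sketch below uses hypothetical helper names and assumes BF16 bit patterns stored in uint16_t and two's-complement int16_t. It maps each pattern to a key whose signed order matches the floating-point order for non-NaN inputs, takes the min or max of the keys, and maps back; the mapping is its own inverse.

```cpp
#include <cstdint>

// Hypothetical scalar model of InvertNonSignBitsIfNegative for one lane.
// For a negative pattern (sign bit set), flip the 15 non-sign bits so that
// more-negative values get smaller signed keys; leave non-negative ones alone.
inline int16_t OrderedKey(uint16_t bf16_bits) {
  const int16_t vi = static_cast<int16_t>(bf16_bits);
  // BroadcastSignBit: all ones for negative lanes, zero otherwise.
  const uint16_t sign_mask = vi < 0 ? 0xFFFFu : 0u;
  // ShiftRight<1> on the unsigned mask gives 0x7FFF or 0.
  const uint16_t invert_mask = static_cast<uint16_t>(sign_mask >> 1);
  return static_cast<int16_t>(bf16_bits ^ invert_mask);
}

// The mapping is an involution, so applying it to a key restores the bits.
inline uint16_t BitsFromKey(int16_t key) {
  return static_cast<uint16_t>(OrderedKey(static_cast<uint16_t>(key)));
}

inline uint16_t BF16MinBits(uint16_t a, uint16_t b) {
  const int16_t ka = OrderedKey(a);
  const int16_t kb = OrderedKey(b);
  return BitsFromKey(ka < kb ? ka : kb);
}

inline uint16_t BF16MaxBits(uint16_t a, uint16_t b) {
  const int16_t ka = OrderedKey(a);
  const int16_t kb = OrderedKey(b);
  return BitsFromKey(ka < kb ? kb : ka);
}

// Clamp to [lo, hi], mirroring Max(Min(v, hi), lo) in the example above.
inline uint16_t BF16ClampBits(uint16_t v, uint16_t lo, uint16_t hi) {
  return BF16MaxBits(BF16MinBits(v, hi), lo);
}
```

For example, clamping 10.0 (0x4120) to [-5.375 (0xC0AC), 2.75 (0x4030)] gives 0x4030, and clamping -10.0 (0xC120) gives 0xC0AC. With this ordering, -0.0 sorts below +0.0.
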
@jan-wassenberg
Member

Interesting, thanks for proposing this. I agree it would be useful to have support for Set() of bfloat16_t. Currently tests use float constants that we then DemoteTo, which is not very convenient.

It's great that your proposal still works in C++11, that is important for us.

Should we change hwy::bfloat16_t to a typedef to std::bfloat16_t (where available) instead of using a wrapper?
That is also the approach we take with float16_t, though last I checked that only worked reliably on Arm and RVV.
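
One way the typedef approach could look, sketched under the assumption that __STDCPP_BFLOAT16_T__ (predefined by C++23 implementations that provide std::bfloat16_t) is the detection signal. The namespace and macro names are hypothetical.

```cpp
#include <cstdint>

// __STDCPP_BFLOAT16_T__ is predefined by C++23 implementations that provide
// std::bfloat16_t in <stdfloat>.
#if defined(__STDCPP_BFLOAT16_T__)
#include <stdfloat>
#define SKETCH_HAVE_NATIVE_BF16 1
#else
#define SKETCH_HAVE_NATIVE_BF16 0
#endif

namespace sketch {  // hypothetical namespace, not Highway's

#if SKETCH_HAVE_NATIVE_BF16
// Native type: arithmetic, literals (2.75bf16) and conversions come for free.
using bfloat16_t = std::bfloat16_t;
#else
// Fallback: a plain container for the 16-bit pattern.
struct bfloat16_t {
  uint16_t bits;
};
#endif

// Either way, vector code can rely on a 2-byte lane type.
static_assert(sizeof(bfloat16_t) == 2, "bfloat16_t must be 2 bytes");

}  // namespace sketch
```

A trade-off of the alias is that std::bfloat16_t has no bits member, so code that reads or writes .bits directly would need to go through a bit cast instead.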

If you'd like to start preparing a pull request, let's first replace each #if/constexpr/#endif sequence with a macro such as HWY_CONSTEXPR_BIT_CAST or HWY_CONSTEXPR2 (or some better name) that expands to either constexpr or nothing, depending on compiler capabilities. This would shorten the places where it is used, and would also help with code folding in IDEs, which might not be able to fold past an #if.
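
A minimal sketch of such macros, using hypothetical names (SKETCH_CXX14_CONSTEXPR, SKETCH_BIT_CAST_CONSTEXPR) rather than whatever Highway eventually adopts. Each macro expands to constexpr or to nothing, so every use site becomes a single token instead of an #if/#endif block.

```cpp
#include <cstdint>
#include <cstring>

#ifdef __has_builtin
#define SKETCH_HAS_BUILTIN(name) __has_builtin(name)
#else
#define SKETCH_HAS_BUILTIN(name) 0
#endif

// Expands to constexpr only if relaxed (C++14) constexpr is available.
#if defined(__cpp_constexpr) && __cpp_constexpr >= 201304L
#define SKETCH_CXX14_CONSTEXPR constexpr
#else
#define SKETCH_CXX14_CONSTEXPR
#endif

// Expands to constexpr only if a constexpr bit cast is available.
#if SKETCH_HAS_BUILTIN(__builtin_bit_cast)
#define SKETCH_BIT_CAST_CONSTEXPR constexpr
#else
#define SKETCH_BIT_CAST_CONSTEXPR
#endif

template <typename To, typename From>
inline SKETCH_BIT_CAST_CONSTEXPR To BitCastSketch(const From& from) noexcept {
  static_assert(sizeof(To) == sizeof(From), "size mismatch");
#if SKETCH_HAS_BUILTIN(__builtin_bit_cast)
  return __builtin_bit_cast(To, from);
#else
  To to;
  std::memcpy(&to, &from, sizeof(To));
  return to;
#endif
}

// The branch and the parameter update require C++14 constexpr rules.
inline SKETCH_CXX14_CONSTEXPR uint32_t QuietIfNaN(uint32_t f32_bits) noexcept {
  if ((f32_bits & 0x7FFFFFFFu) > 0x7F800000u) {
    f32_bits |= 0x00400000u;  // set the quiet bit of a binary32 NaN
  }
  return f32_bits;
}
```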
