Skip to content

Commit

Permalink
Add a bf16 datatype to sfnp (#914)
Browse files Browse the repository at this point in the history
Add bf16 datatype, functions, and tests
  • Loading branch information
KyleHerndon authored Feb 13, 2025
1 parent b3a5219 commit af8ee1c
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 0 deletions.
1 change: 1 addition & 0 deletions shortfin/python/array_binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ class Refs {
add_type(DType::int64(), "q", sizeof(signed long long));
add_type(DType::sint64(), "q", sizeof(signed long long));
add_type(DType::uint64(), "Q", sizeof(unsigned long long));
add_type(DType::bfloat16(), "H", sizeof(unsigned short));
add_type(DType::float16(), "H", sizeof(unsigned short));
add_type(DType::float32(), "f", sizeof(float));
add_type(DType::float64(), "d", sizeof(double));
Expand Down
124 changes: 124 additions & 0 deletions shortfin/python/array_host_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,116 @@
#include "xtensor/xsort.hpp"
#include "xtl/xhalf_float.hpp"

#ifndef BFLOAT16_HPP
#define BFLOAT16_HPP

#include <bit>
#include <cstdint>
#include <limits>
#include <type_traits>

struct bfloat16_t {
uint16_t value;

constexpr bfloat16_t() noexcept : value(0) {}

explicit constexpr bfloat16_t(float f) noexcept {
uint32_t temp = std::bit_cast<uint32_t>(f);
value = static_cast<uint16_t>(temp >> 16);
}

template <typename T, typename = std::enable_if_t<std::is_arithmetic_v<T> &&
!std::is_same_v<T, float>>>
constexpr bfloat16_t(T value) noexcept
: bfloat16_t(static_cast<float>(value)) {}

constexpr operator float() const noexcept {
uint32_t temp = static_cast<uint32_t>(value) << 16;
return std::bit_cast<float>(temp);
}

// Arithmetic operators (implemented via conversion to float)
constexpr bfloat16_t operator+(const bfloat16_t &other) const noexcept {
return bfloat16_t(float(*this) + float(other));
}
constexpr bfloat16_t operator-(const bfloat16_t &other) const noexcept {
return bfloat16_t(float(*this) - float(other));
}
constexpr bfloat16_t operator*(const bfloat16_t &other) const noexcept {
return bfloat16_t(float(*this) * float(other));
}
constexpr bfloat16_t operator/(const bfloat16_t &other) const noexcept {
return bfloat16_t(float(*this) / float(other));
}

constexpr bfloat16_t &operator+=(const bfloat16_t &other) noexcept {
*this = *this + other;
return *this;
}
constexpr bfloat16_t &operator-=(const bfloat16_t &other) noexcept {
*this = *this - other;
return *this;
}
constexpr bfloat16_t &operator*=(const bfloat16_t &other) noexcept {
*this = *this * other;
return *this;
}
constexpr bfloat16_t &operator/=(const bfloat16_t &other) noexcept {
*this = *this / other;
return *this;
}

// Comparison operators (using conversion to float)
constexpr bool operator==(const bfloat16_t &other) const noexcept {
return float(*this) == float(other);
}
constexpr bool operator!=(const bfloat16_t &other) const noexcept {
return !(*this == other);
}
constexpr bool operator<(const bfloat16_t &other) const noexcept {
return float(*this) < float(other);
}
constexpr bool operator<=(const bfloat16_t &other) const noexcept {
return float(*this) <= float(other);
}
constexpr bool operator>(const bfloat16_t &other) const noexcept {
return float(*this) > float(other);
}
constexpr bool operator>=(const bfloat16_t &other) const noexcept {
return float(*this) >= float(other);
}
};

// Mark bfloat16_t as a trivial, standard-layout type so that xtensor can use
// it.
namespace std {
template <>
struct is_trivial<bfloat16_t> : std::true_type {};
template <>
struct is_standard_layout<bfloat16_t> : std::true_type {};
template <>
struct is_trivially_copyable<bfloat16_t> : std::true_type {};
} // namespace std

// Math functions needed by xtensor for bfloat16_t
inline constexpr bfloat16_t round(bfloat16_t x) noexcept {
return bfloat16_t(std::round(float(x)));
}

inline constexpr bfloat16_t ceil(bfloat16_t x) noexcept {
return bfloat16_t(std::ceil(float(x)));
}

inline constexpr bfloat16_t floor(bfloat16_t x) noexcept {
return bfloat16_t(std::floor(float(x)));
}

inline constexpr bfloat16_t trunc(bfloat16_t x) noexcept {
return bfloat16_t(std::trunc(float(x)));
}

#endif // BFLOAT16_HPP

using namespace shortfin::array;

namespace shortfin::python {
Expand Down Expand Up @@ -191,6 +301,7 @@ struct ConvertFunctor {
}
switch (dtype) {
SF_STORE_CASE(float16, half_float::half);
SF_STORE_CASE(bfloat16, bfloat16_t);
SF_STORE_CASE(float32, float);
SF_STORE_CASE(float64, double);
SF_STORE_CASE(uint8, uint8_t);
Expand All @@ -210,6 +321,7 @@ struct ConvertFunctor {

switch (input.dtype()) {
SF_UNARY_THUNK_CASE(float16, half_float::half);
SF_UNARY_THUNK_CASE(bfloat16, bfloat16_t);
SF_UNARY_THUNK_CASE(float32, float);
SF_UNARY_THUNK_CASE(float64, double);
SF_UNARY_THUNK_CASE(uint8, uint8_t);
Expand Down Expand Up @@ -264,6 +376,7 @@ struct ConvertRoundFunctor {

switch (input.dtype()) {
SF_UNARY_THUNK_CASE(float16, half_float::half);
SF_UNARY_THUNK_CASE(bfloat16, bfloat16_t);
SF_UNARY_THUNK_CASE(float32, float);
default:
throw std::invalid_argument(fmt::format(
Expand Down Expand Up @@ -308,6 +421,7 @@ struct ConvertCeilFunctor {

switch (input.dtype()) {
SF_UNARY_THUNK_CASE(float16, half_float::half);
SF_UNARY_THUNK_CASE(bfloat16, bfloat16_t);
SF_UNARY_THUNK_CASE(float32, float);
default:
throw std::invalid_argument(fmt::format(
Expand Down Expand Up @@ -352,6 +466,7 @@ struct ConvertFloorFunctor {

switch (input.dtype()) {
SF_UNARY_THUNK_CASE(float16, half_float::half);
SF_UNARY_THUNK_CASE(bfloat16, bfloat16_t);
SF_UNARY_THUNK_CASE(float32, float);
default:
throw std::invalid_argument(fmt::format(
Expand Down Expand Up @@ -396,6 +511,7 @@ struct ConvertTruncFunctor {

switch (input.dtype()) {
SF_UNARY_THUNK_CASE(float16, half_float::half);
SF_UNARY_THUNK_CASE(bfloat16, bfloat16_t);
SF_UNARY_THUNK_CASE(float32, float);
default:
throw std::invalid_argument(fmt::format(
Expand Down Expand Up @@ -525,6 +641,11 @@ half_float::half ConvertPyToEltTy(py::handle py_value, half_float::half zero) {
return static_cast<half_float::half>(py::cast<double>(py_value));
}

bfloat16_t ConvertPyToEltTy(py::handle py_value, bfloat16_t zero) {
// Python can't cast directly to half so first go to double.
return static_cast<bfloat16_t>(py::cast<double>(py_value));
}

struct AddFunctor {
template <typename Lhs, typename Rhs>
static auto Invoke(Lhs &&lhs, Rhs &&rhs) {
Expand Down Expand Up @@ -610,6 +731,7 @@ device_array ElementwiseOperation(py::handle lhs, py::handle rhs,

switch (dtype) {
SF_UNARY_FUNCTION_CASE(float16, half_float::half);
SF_UNARY_FUNCTION_CASE(bfloat16, bfloat16_t);
SF_UNARY_FUNCTION_CASE(float32, float);
SF_UNARY_FUNCTION_CASE(float64, double);
SF_UNARY_FUNCTION_CASE(uint8, uint8_t);
Expand Down Expand Up @@ -661,6 +783,7 @@ void BindArrayHostOps(py::module_ &m) {

switch (input.dtype()) {
SF_UNARY_FUNCTION_CASE(float16, half_float::half);
SF_UNARY_FUNCTION_CASE(bfloat16, bfloat16_t);
SF_UNARY_FUNCTION_CASE(float32, float);
default:
throw std::invalid_argument(
Expand Down Expand Up @@ -690,6 +813,7 @@ void BindArrayHostOps(py::module_ &m) {

switch (out.dtype()) {
SF_UNARY_FUNCTION_CASE(float16, half_float::half);
SF_UNARY_FUNCTION_CASE(bfloat16, bfloat16_t);
SF_UNARY_FUNCTION_CASE(float32, float);
default:
throw std::invalid_argument(
Expand Down
15 changes: 15 additions & 0 deletions shortfin/tests/api/array_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def test_argmax_axis0(device):
@pytest.mark.parametrize(
"dtype",
[
sfnp.bfloat16,
sfnp.float16,
sfnp.float32,
],
Expand All @@ -114,6 +115,7 @@ def test_argmax_dtypes(device, dtype):
@pytest.mark.parametrize(
"dtype",
[
sfnp.bfloat16,
sfnp.float16,
sfnp.float32,
],
Expand All @@ -138,6 +140,7 @@ def test_fill_randn_default_generator(device, dtype):
@pytest.mark.parametrize(
"dtype",
[
sfnp.bfloat16,
sfnp.float16,
sfnp.float32,
],
Expand Down Expand Up @@ -180,6 +183,7 @@ def test_fill_randn_explicit_generator(device, dtype):
sfnp.int16,
sfnp.int32,
sfnp.int64,
sfnp.bfloat16,
sfnp.float16,
sfnp.float32,
sfnp.float64,
Expand Down Expand Up @@ -208,12 +212,16 @@ def round_half_away_from_zero(n):
@pytest.mark.parametrize(
"dtype,sfnp_func,ref_round_func",
[
(sfnp.bfloat16, sfnp.round, round_half_away_from_zero),
(sfnp.float16, sfnp.round, round_half_away_from_zero),
(sfnp.float32, sfnp.round, round_half_away_from_zero),
(sfnp.bfloat16, sfnp.ceil, math.ceil),
(sfnp.float16, sfnp.ceil, math.ceil),
(sfnp.float32, sfnp.ceil, math.ceil),
(sfnp.bfloat16, sfnp.floor, math.floor),
(sfnp.float16, sfnp.floor, math.floor),
(sfnp.float32, sfnp.floor, math.floor),
(sfnp.bfloat16, sfnp.trunc, math.trunc),
(sfnp.float16, sfnp.trunc, math.trunc),
(sfnp.float32, sfnp.trunc, math.trunc),
],
Expand Down Expand Up @@ -309,6 +317,8 @@ def test_elementwise_forms(device):
@pytest.mark.parametrize(
"lhs_dtype,rhs_dtype,promoted_dtype",
[
(sfnp.float32, sfnp.bfloat16, sfnp.float32),
(sfnp.bfloat16, sfnp.float32, sfnp.float32),
(sfnp.float32, sfnp.float16, sfnp.float32),
(sfnp.float16, sfnp.float32, sfnp.float32),
(sfnp.float32, sfnp.float64, sfnp.float64),
Expand Down Expand Up @@ -347,6 +357,7 @@ def test_elementwise_promotion(device, lhs_dtype, rhs_dtype, promoted_dtype):
(sfnp.uint16, sfnp.add, 44.0),
(sfnp.uint32, sfnp.add, 44.0),
(sfnp.uint64, sfnp.add, 44.0),
(sfnp.bfloat16, sfnp.add, 44.0),
(sfnp.float16, sfnp.add, 44.0),
(sfnp.float32, sfnp.add, 44.0),
(sfnp.float64, sfnp.add, 44.0),
Expand All @@ -359,6 +370,7 @@ def test_elementwise_promotion(device, lhs_dtype, rhs_dtype, promoted_dtype):
(sfnp.uint16, sfnp.divide, 21.0),
(sfnp.uint32, sfnp.divide, 21.0),
(sfnp.uint64, sfnp.divide, 21.0),
(sfnp.bfloat16, sfnp.divide, 21.0),
(sfnp.float16, sfnp.divide, 21.0),
(sfnp.float32, sfnp.divide, 21.0),
(sfnp.float64, sfnp.divide, 21.0),
Expand All @@ -371,6 +383,7 @@ def test_elementwise_promotion(device, lhs_dtype, rhs_dtype, promoted_dtype):
(sfnp.uint16, sfnp.multiply, 84.0),
(sfnp.uint32, sfnp.multiply, 84.0),
(sfnp.uint64, sfnp.multiply, 84.0),
(sfnp.bfloat16, sfnp.multiply, 84.0),
(sfnp.float16, sfnp.multiply, 84.0),
(sfnp.float32, sfnp.multiply, 84.0),
(sfnp.float64, sfnp.multiply, 84.0),
Expand All @@ -383,6 +396,7 @@ def test_elementwise_promotion(device, lhs_dtype, rhs_dtype, promoted_dtype):
(sfnp.uint16, sfnp.subtract, 40.0),
(sfnp.uint32, sfnp.subtract, 40.0),
(sfnp.uint64, sfnp.subtract, 40.0),
(sfnp.bfloat16, sfnp.subtract, 40.0),
(sfnp.float16, sfnp.subtract, 40.0),
(sfnp.float32, sfnp.subtract, 40.0),
(sfnp.float64, sfnp.subtract, 40.0),
Expand Down Expand Up @@ -418,6 +432,7 @@ def test_elementwise_array_correctness(device, dtype, op, check_value):
sfnp.uint32,
sfnp.uint64,
sfnp.float32,
sfnp.bfloat16,
sfnp.float16,
sfnp.float32,
sfnp.float64,
Expand Down

0 comments on commit af8ee1c

Please sign in to comment.