From af8ee1c607dc2928029830d696af6c5f0a4b10e6 Mon Sep 17 00:00:00 2001 From: Kyle Herndon Date: Wed, 12 Feb 2025 17:43:32 -0800 Subject: [PATCH] Add a bf16 datatype to sfnp (#914) Add bf16 datatype, functions, and tests --- shortfin/python/array_binding.cc | 1 + shortfin/python/array_host_ops.cc | 124 +++++++++++++++++++++++++++ shortfin/tests/api/array_ops_test.py | 15 ++++ 3 files changed, 140 insertions(+) diff --git a/shortfin/python/array_binding.cc b/shortfin/python/array_binding.cc index 4489eaeda..38974cc4c 100644 --- a/shortfin/python/array_binding.cc +++ b/shortfin/python/array_binding.cc @@ -202,6 +202,7 @@ class Refs { add_type(DType::int64(), "q", sizeof(signed long long)); add_type(DType::sint64(), "q", sizeof(signed long long)); add_type(DType::uint64(), "Q", sizeof(unsigned long long)); + add_type(DType::bfloat16(), "H", sizeof(unsigned short)); add_type(DType::float16(), "H", sizeof(unsigned short)); add_type(DType::float32(), "f", sizeof(float)); add_type(DType::float64(), "d", sizeof(double)); diff --git a/shortfin/python/array_host_ops.cc b/shortfin/python/array_host_ops.cc index 3e2a8ebe3..55e0546a3 100644 --- a/shortfin/python/array_host_ops.cc +++ b/shortfin/python/array_host_ops.cc @@ -12,6 +12,116 @@ #include "xtensor/xsort.hpp" #include "xtl/xhalf_float.hpp" +#ifndef BFLOAT16_HPP +#define BFLOAT16_HPP + +#include +#include +#include +#include + +struct bfloat16_t { + uint16_t value; + + constexpr bfloat16_t() noexcept : value(0) {} + + explicit constexpr bfloat16_t(float f) noexcept { + uint32_t temp = std::bit_cast(f); + value = static_cast(temp >> 16); + } + + template && + !std::is_same_v>> + constexpr bfloat16_t(T value) noexcept + : bfloat16_t(static_cast(value)) {} + + constexpr operator float() const noexcept { + uint32_t temp = static_cast(value) << 16; + return std::bit_cast(temp); + } + + // Arithmetic operators: operands are widened to float and the float result is narrowed back through the explicit float constructor above, which keeps only the upper 16 bits (truncation, no rounding). + constexpr bfloat16_t operator+(const bfloat16_t &other) const noexcept { 
return bfloat16_t(float(*this) + float(other)); + } + constexpr bfloat16_t operator-(const bfloat16_t &other) const noexcept { + return bfloat16_t(float(*this) - float(other)); + } + constexpr bfloat16_t operator*(const bfloat16_t &other) const noexcept { + return bfloat16_t(float(*this) * float(other)); + } + constexpr bfloat16_t operator/(const bfloat16_t &other) const noexcept { + return bfloat16_t(float(*this) / float(other)); + } + + constexpr bfloat16_t &operator+=(const bfloat16_t &other) noexcept { + *this = *this + other; + return *this; + } + constexpr bfloat16_t &operator-=(const bfloat16_t &other) noexcept { + *this = *this - other; + return *this; + } + constexpr bfloat16_t &operator*=(const bfloat16_t &other) noexcept { + *this = *this * other; + return *this; + } + constexpr bfloat16_t &operator/=(const bfloat16_t &other) noexcept { + *this = *this / other; + return *this; + } + + // Comparison operators: both operands are converted to float and compared as floats. + constexpr bool operator==(const bfloat16_t &other) const noexcept { + return float(*this) == float(other); + } + constexpr bool operator!=(const bfloat16_t &other) const noexcept { + return !(*this == other); + } + constexpr bool operator<(const bfloat16_t &other) const noexcept { + return float(*this) < float(other); + } + constexpr bool operator<=(const bfloat16_t &other) const noexcept { + return float(*this) <= float(other); + } + constexpr bool operator>(const bfloat16_t &other) const noexcept { + return float(*this) > float(other); + } + constexpr bool operator>=(const bfloat16_t &other) const noexcept { + return float(*this) >= float(other); + } +}; + +// Mark bfloat16_t as a trivial, standard-layout type so that xtensor can use +// it. NOTE(review): the specializations that follow add to std type traits, which the C++ standard declares undefined behavior; confirm xtensor really requires them.
+namespace std { +template <> +struct is_trivial : std::true_type {}; +template <> +struct is_standard_layout : std::true_type {}; +template <> +struct is_trivially_copyable : std::true_type {}; +} // namespace std + +// Math functions needed by xtensor for bfloat16_t: each converts to float, calls the matching std:: function, and wraps the result back into bfloat16_t. +inline constexpr bfloat16_t round(bfloat16_t x) noexcept { + return bfloat16_t(std::round(float(x))); +} + +inline constexpr bfloat16_t ceil(bfloat16_t x) noexcept { + return bfloat16_t(std::ceil(float(x))); +} + +inline constexpr bfloat16_t floor(bfloat16_t x) noexcept { + return bfloat16_t(std::floor(float(x))); +} + +inline constexpr bfloat16_t trunc(bfloat16_t x) noexcept { + return bfloat16_t(std::trunc(float(x))); +} + +#endif // BFLOAT16_HPP + using namespace shortfin::array; namespace shortfin::python { @@ -191,6 +301,7 @@ struct ConvertFunctor { } switch (dtype) { SF_STORE_CASE(float16, half_float::half); + SF_STORE_CASE(bfloat16, bfloat16_t); SF_STORE_CASE(float32, float); SF_STORE_CASE(float64, double); SF_STORE_CASE(uint8, uint8_t); @@ -210,6 +321,7 @@ struct ConvertFunctor { switch (input.dtype()) { SF_UNARY_THUNK_CASE(float16, half_float::half); + SF_UNARY_THUNK_CASE(bfloat16, bfloat16_t); SF_UNARY_THUNK_CASE(float32, float); SF_UNARY_THUNK_CASE(float64, double); SF_UNARY_THUNK_CASE(uint8, uint8_t); @@ -264,6 +376,7 @@ struct ConvertRoundFunctor { switch (input.dtype()) { SF_UNARY_THUNK_CASE(float16, half_float::half); + SF_UNARY_THUNK_CASE(bfloat16, bfloat16_t); SF_UNARY_THUNK_CASE(float32, float); default: throw std::invalid_argument(fmt::format( @@ -308,6 +421,7 @@ struct ConvertCeilFunctor { switch (input.dtype()) { SF_UNARY_THUNK_CASE(float16, half_float::half); + SF_UNARY_THUNK_CASE(bfloat16, bfloat16_t); SF_UNARY_THUNK_CASE(float32, float); default: throw std::invalid_argument(fmt::format( @@ -352,6 +466,7 @@ struct ConvertFloorFunctor { switch (input.dtype()) { SF_UNARY_THUNK_CASE(float16, half_float::half); + SF_UNARY_THUNK_CASE(bfloat16, bfloat16_t); 
SF_UNARY_THUNK_CASE(float32, float); default: throw std::invalid_argument(fmt::format( @@ -396,6 +511,7 @@ struct ConvertTruncFunctor { switch (input.dtype()) { SF_UNARY_THUNK_CASE(float16, half_float::half); + SF_UNARY_THUNK_CASE(bfloat16, bfloat16_t); SF_UNARY_THUNK_CASE(float32, float); default: throw std::invalid_argument(fmt::format( @@ -525,6 +641,11 @@ half_float::half ConvertPyToEltTy(py::handle py_value, half_float::half zero) { return static_cast(py::cast(py_value)); } +bfloat16_t ConvertPyToEltTy(py::handle py_value, bfloat16_t zero) { + // Python can't cast directly to bfloat16 so first go to double. + return static_cast(py::cast(py_value)); +} + struct AddFunctor { template static auto Invoke(Lhs &&lhs, Rhs &&rhs) { @@ -610,6 +731,7 @@ device_array ElementwiseOperation(py::handle lhs, py::handle rhs, switch (dtype) { SF_UNARY_FUNCTION_CASE(float16, half_float::half); + SF_UNARY_FUNCTION_CASE(bfloat16, bfloat16_t); SF_UNARY_FUNCTION_CASE(float32, float); SF_UNARY_FUNCTION_CASE(float64, double); SF_UNARY_FUNCTION_CASE(uint8, uint8_t); @@ -661,6 +783,7 @@ void BindArrayHostOps(py::module_ &m) { switch (input.dtype()) { SF_UNARY_FUNCTION_CASE(float16, half_float::half); + SF_UNARY_FUNCTION_CASE(bfloat16, bfloat16_t); SF_UNARY_FUNCTION_CASE(float32, float); default: throw std::invalid_argument( @@ -690,6 +813,7 @@ void BindArrayHostOps(py::module_ &m) { switch (out.dtype()) { SF_UNARY_FUNCTION_CASE(float16, half_float::half); + SF_UNARY_FUNCTION_CASE(bfloat16, bfloat16_t); SF_UNARY_FUNCTION_CASE(float32, float); default: throw std::invalid_argument( diff --git a/shortfin/tests/api/array_ops_test.py b/shortfin/tests/api/array_ops_test.py index 164dfb479..3e778d421 100644 --- a/shortfin/tests/api/array_ops_test.py +++ b/shortfin/tests/api/array_ops_test.py @@ -100,6 +100,7 @@ def test_argmax_axis0(device): @pytest.mark.parametrize( "dtype", [ + sfnp.bfloat16, sfnp.float16, sfnp.float32, ], @@ -114,6 +115,7 @@ def test_argmax_dtypes(device, dtype): 
@pytest.mark.parametrize( "dtype", [ + sfnp.bfloat16, sfnp.float16, sfnp.float32, ], @@ -138,6 +140,7 @@ def test_fill_randn_default_generator(device, dtype): @pytest.mark.parametrize( "dtype", [ + sfnp.bfloat16, sfnp.float16, sfnp.float32, ], @@ -180,6 +183,7 @@ def test_fill_randn_explicit_generator(device, dtype): sfnp.int16, sfnp.int32, sfnp.int64, + sfnp.bfloat16, sfnp.float16, sfnp.float32, sfnp.float64, @@ -208,12 +212,16 @@ def round_half_away_from_zero(n): @pytest.mark.parametrize( "dtype,sfnp_func,ref_round_func", [ + (sfnp.bfloat16, sfnp.round, round_half_away_from_zero), (sfnp.float16, sfnp.round, round_half_away_from_zero), (sfnp.float32, sfnp.round, round_half_away_from_zero), + (sfnp.bfloat16, sfnp.ceil, math.ceil), (sfnp.float16, sfnp.ceil, math.ceil), (sfnp.float32, sfnp.ceil, math.ceil), + (sfnp.bfloat16, sfnp.floor, math.floor), (sfnp.float16, sfnp.floor, math.floor), (sfnp.float32, sfnp.floor, math.floor), + (sfnp.bfloat16, sfnp.trunc, math.trunc), (sfnp.float16, sfnp.trunc, math.trunc), (sfnp.float32, sfnp.trunc, math.trunc), ], @@ -309,6 +317,8 @@ def test_elementwise_forms(device): @pytest.mark.parametrize( "lhs_dtype,rhs_dtype,promoted_dtype", [ + (sfnp.float32, sfnp.bfloat16, sfnp.float32), + (sfnp.bfloat16, sfnp.float32, sfnp.float32), (sfnp.float32, sfnp.float16, sfnp.float32), (sfnp.float16, sfnp.float32, sfnp.float32), (sfnp.float32, sfnp.float64, sfnp.float64), @@ -347,6 +357,7 @@ def test_elementwise_promotion(device, lhs_dtype, rhs_dtype, promoted_dtype): (sfnp.uint16, sfnp.add, 44.0), (sfnp.uint32, sfnp.add, 44.0), (sfnp.uint64, sfnp.add, 44.0), + (sfnp.bfloat16, sfnp.add, 44.0), (sfnp.float16, sfnp.add, 44.0), (sfnp.float32, sfnp.add, 44.0), (sfnp.float64, sfnp.add, 44.0), @@ -359,6 +370,7 @@ def test_elementwise_promotion(device, lhs_dtype, rhs_dtype, promoted_dtype): (sfnp.uint16, sfnp.divide, 21.0), (sfnp.uint32, sfnp.divide, 21.0), (sfnp.uint64, sfnp.divide, 21.0), + (sfnp.bfloat16, sfnp.divide, 21.0), (sfnp.float16, 
sfnp.divide, 21.0), (sfnp.float32, sfnp.divide, 21.0), (sfnp.float64, sfnp.divide, 21.0), @@ -371,6 +383,7 @@ def test_elementwise_promotion(device, lhs_dtype, rhs_dtype, promoted_dtype): (sfnp.uint16, sfnp.multiply, 84.0), (sfnp.uint32, sfnp.multiply, 84.0), (sfnp.uint64, sfnp.multiply, 84.0), + (sfnp.bfloat16, sfnp.multiply, 84.0), (sfnp.float16, sfnp.multiply, 84.0), (sfnp.float32, sfnp.multiply, 84.0), (sfnp.float64, sfnp.multiply, 84.0), @@ -383,6 +396,7 @@ def test_elementwise_promotion(device, lhs_dtype, rhs_dtype, promoted_dtype): (sfnp.uint16, sfnp.subtract, 40.0), (sfnp.uint32, sfnp.subtract, 40.0), (sfnp.uint64, sfnp.subtract, 40.0), + (sfnp.bfloat16, sfnp.subtract, 40.0), (sfnp.float16, sfnp.subtract, 40.0), (sfnp.float32, sfnp.subtract, 40.0), (sfnp.float64, sfnp.subtract, 40.0), @@ -418,6 +432,7 @@ def test_elementwise_array_correctness(device, dtype, op, check_value): sfnp.uint32, sfnp.uint64, sfnp.float32, + sfnp.bfloat16, sfnp.float16, sfnp.float32, sfnp.float64,