Skip to content

Commit

Permalink
Run pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
KyleHerndon committed Feb 12, 2025
1 parent 2a03876 commit 6f547a0
Showing 1 changed file with 48 additions and 46 deletions.
94 changes: 48 additions & 46 deletions shortfin/python/array_host_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#include "xtensor/xsort.hpp"
#include "xtl/xhalf_float.hpp"


#ifndef BFLOAT16_HPP
#define BFLOAT16_HPP

Expand All @@ -27,81 +26,84 @@ struct bfloat16_t {
constexpr bfloat16_t() noexcept : value(0) {}

explicit constexpr bfloat16_t(float f) noexcept {
uint32_t temp = std::bit_cast<uint32_t>(f);
value = static_cast<uint16_t>(temp >> 16);
uint32_t temp = std::bit_cast<uint32_t>(f);
value = static_cast<uint16_t>(temp >> 16);
}

template <typename T, typename = std::enable_if_t<std::is_arithmetic_v<T> &&
!std::is_same_v<T, float>>>
constexpr bfloat16_t(T value) noexcept
: bfloat16_t(static_cast<float>(value)) {}

constexpr operator float() const noexcept {
uint32_t temp = static_cast<uint32_t>(value) << 16;
return std::bit_cast<float>(temp);
uint32_t temp = static_cast<uint32_t>(value) << 16;
return std::bit_cast<float>(temp);
}

// Arithmetic operators (implemented via conversion to float)
constexpr bfloat16_t operator+(const bfloat16_t& other) const noexcept {
return bfloat16_t(float(*this) + float(other));
constexpr bfloat16_t operator+(const bfloat16_t &other) const noexcept {
return bfloat16_t(float(*this) + float(other));
}
constexpr bfloat16_t operator-(const bfloat16_t& other) const noexcept {
return bfloat16_t(float(*this) - float(other));
constexpr bfloat16_t operator-(const bfloat16_t &other) const noexcept {
return bfloat16_t(float(*this) - float(other));
}
constexpr bfloat16_t operator*(const bfloat16_t& other) const noexcept {
return bfloat16_t(float(*this) * float(other));
constexpr bfloat16_t operator*(const bfloat16_t &other) const noexcept {
return bfloat16_t(float(*this) * float(other));
}
constexpr bfloat16_t operator/(const bfloat16_t& other) const noexcept {
return bfloat16_t(float(*this) / float(other));
constexpr bfloat16_t operator/(const bfloat16_t &other) const noexcept {
return bfloat16_t(float(*this) / float(other));
}

constexpr bfloat16_t& operator+=(const bfloat16_t& other) noexcept {
*this = *this + other;
return *this;
constexpr bfloat16_t &operator+=(const bfloat16_t &other) noexcept {
*this = *this + other;
return *this;
}
constexpr bfloat16_t& operator-=(const bfloat16_t& other) noexcept {
*this = *this - other;
return *this;
constexpr bfloat16_t &operator-=(const bfloat16_t &other) noexcept {
*this = *this - other;
return *this;
}
constexpr bfloat16_t& operator*=(const bfloat16_t& other) noexcept {
*this = *this * other;
return *this;
constexpr bfloat16_t &operator*=(const bfloat16_t &other) noexcept {
*this = *this * other;
return *this;
}
constexpr bfloat16_t& operator/=(const bfloat16_t& other) noexcept {
*this = *this / other;
return *this;
constexpr bfloat16_t &operator/=(const bfloat16_t &other) noexcept {
*this = *this / other;
return *this;
}

// Comparison operators (using conversion to float)
constexpr bool operator==(const bfloat16_t& other) const noexcept {
return float(*this) == float(other);
constexpr bool operator==(const bfloat16_t &other) const noexcept {
return float(*this) == float(other);
}
constexpr bool operator!=(const bfloat16_t& other) const noexcept {
return !(*this == other);
constexpr bool operator!=(const bfloat16_t &other) const noexcept {
return !(*this == other);
}
constexpr bool operator<(const bfloat16_t& other) const noexcept {
return float(*this) < float(other);
constexpr bool operator<(const bfloat16_t &other) const noexcept {
return float(*this) < float(other);
}
constexpr bool operator<=(const bfloat16_t& other) const noexcept {
return float(*this) <= float(other);
constexpr bool operator<=(const bfloat16_t &other) const noexcept {
return float(*this) <= float(other);
}
constexpr bool operator>(const bfloat16_t& other) const noexcept {
return float(*this) > float(other);
constexpr bool operator>(const bfloat16_t &other) const noexcept {
return float(*this) > float(other);
}
constexpr bool operator>=(const bfloat16_t& other) const noexcept {
return float(*this) >= float(other);
constexpr bool operator>=(const bfloat16_t &other) const noexcept {
return float(*this) >= float(other);
}
};

// Mark bfloat16_t as a trivial, standard-layout type so that xtensor can use it.
// Mark bfloat16_t as a trivial, standard-layout type so that xtensor can use
// it.
namespace std {
template<> struct is_trivial<bfloat16_t> : std::true_type {};
template<> struct is_standard_layout<bfloat16_t> : std::true_type {};
template<> struct is_trivially_copyable<bfloat16_t> : std::true_type {};
}

#endif // BFLOAT16_HPP

template <>
struct is_trivial<bfloat16_t> : std::true_type {};
template <>
struct is_standard_layout<bfloat16_t> : std::true_type {};
template <>
struct is_trivially_copyable<bfloat16_t> : std::true_type {};
} // namespace std

#endif // BFLOAT16_HPP

using namespace shortfin::array;

Expand Down

0 comments on commit 6f547a0

Please sign in to comment.