diff --git a/tt-train/sources/ttml/core/ttnn_all_includes.hpp b/tt-train/sources/ttml/core/ttnn_all_includes.hpp index a675a76a3b0..b3b62da5295 100644 --- a/tt-train/sources/ttml/core/ttnn_all_includes.hpp +++ b/tt-train/sources/ttml/core/ttnn_all_includes.hpp @@ -72,4 +72,6 @@ #include // NOLINT #include // NOLINT #include // NOLINT + +#include "ttnn/operations/experimental/transformer/rotary_embedding_llama/rotary_embedding_llama.hpp" // NOLINT #pragma GCC diagnostic pop diff --git a/tt-train/sources/ttml/modules/rotary_embedding.cpp b/tt-train/sources/ttml/modules/rotary_embedding.cpp new file mode 100644 index 00000000000..aed5cf1db05 --- /dev/null +++ b/tt-train/sources/ttml/modules/rotary_embedding.cpp @@ -0,0 +1,19 @@ +// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "modules/rotary_embedding.hpp" + +#include "autograd/auto_context.hpp" +#include "ops/rope_op.hpp" + +namespace ttml::modules { + +RotaryEmbedding::RotaryEmbedding(const ops::RotaryEmbeddingParams& rope_params) : m_rope_params(rope_params) { +} + +autograd::TensorPtr RotaryEmbedding::operator()(const autograd::TensorPtr& input) { + return ttml::ops::rope(input, m_rope_params); +} + +} // namespace ttml::modules diff --git a/tt-train/sources/ttml/modules/rotary_embedding.hpp b/tt-train/sources/ttml/modules/rotary_embedding.hpp new file mode 100644 index 00000000000..2f80bad437d --- /dev/null +++ b/tt-train/sources/ttml/modules/rotary_embedding.hpp @@ -0,0 +1,20 @@ +// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "autograd/module_base.hpp" +#include "ops/rope_op.hpp" + +namespace ttml::modules { +class RotaryEmbedding : public autograd::ModuleBase { +private: + ops::RotaryEmbeddingParams m_rope_params; + +public: + explicit RotaryEmbedding(const ops::RotaryEmbeddingParams &rope_params); + [[nodiscard]] autograd::TensorPtr operator()(const autograd::TensorPtr 
&input) override; +}; + +} // namespace ttml::modules diff --git a/tt-train/sources/ttml/ops/rope_op.cpp b/tt-train/sources/ttml/ops/rope_op.cpp new file mode 100644 index 00000000000..a0cbc9773be --- /dev/null +++ b/tt-train/sources/ttml/ops/rope_op.cpp @@ -0,0 +1,167 @@ +// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ops/rope_op.hpp" + +#include + +#include "autograd/auto_context.hpp" +#include "autograd/graph.hpp" +#include "autograd/graph_utils.hpp" +#include "autograd/tensor.hpp" +#include "core/tt_tensor_utils.hpp" +#include "core/ttnn_all_includes.hpp" +#include "core/xtensor_utils.hpp" +#include "ttnn/tensor/xtensor/xtensor_all_includes.hpp" +#include "ttnn_fixed/trivial_ttnn_ops.hpp" + +namespace ttml::ops { + +void validate_rope_input_and_params(const autograd::TensorPtr& input, const RotaryEmbeddingParams& params) { + if (input->get_rank() != 4U) { + throw std::runtime_error( + fmt::format("RoPE only supports rank-4 input tensors, but got rank {}.", input->get_rank())); + } + auto input_shape = input->get_shape(); + + auto input_seq_len = input_shape[-2]; + auto input_head_dim = input_shape[-1]; + + if (input_head_dim != params.head_dim) { + throw std::runtime_error(fmt::format( + "RoPE input tensor's head dimension ({}) must match the head dimension in the params ({})", + input_head_dim, + params.head_dim)); + } + + if (input_seq_len != params.sequence_length) { + throw std::runtime_error(fmt::format( + "RoPE input tensor's sequence length ({}) must match the sequence length in the params ({})", + input_seq_len, + params.sequence_length)); + } + + auto trans_mat_shape = params.trans_mat.get_logical_shape(); + auto trig_param_shapes = std::array{ + params.cos_cache.get_logical_shape(), + params.sin_cache.get_logical_shape(), + params.neg_cos_cache.get_logical_shape(), + params.neg_sin_cache.get_logical_shape()}; + + auto expected_trig_shape = ttnn::Shape{1U, 1U, input_seq_len, 
input_head_dim}; + if (!std::ranges::all_of( + trig_param_shapes, [&expected_trig_shape](auto shape) { return shape == expected_trig_shape; })) { + throw std::runtime_error(fmt::format( + "All trigonometric rotary embedding parameters must have shape [1, 1, {}, {}], but got shapes: " + "cos_cache: {}, sin_cache: {}, neg_cos_cache: {}, neg_sin_cache: {}", + input_seq_len, + input_head_dim, + params.cos_cache.get_logical_shape(), + params.sin_cache.get_logical_shape(), + params.neg_cos_cache.get_logical_shape(), + params.neg_sin_cache.get_logical_shape())); + } + + auto expected_trans_mat_shape = ttnn::Shape{1U, 1U, 32U, 32U}; + if (trans_mat_shape != expected_trans_mat_shape) { + throw std::runtime_error(fmt::format( + "RoPE trans_mat must be of shape {}, but has shape {}", expected_trans_mat_shape, trans_mat_shape)); + } +} + +// trans_mat, sin_cache, cos_cache are all precomputed and stored somewhere in +// the module hierarchy and passed to the operation. +autograd::TensorPtr rope(const autograd::TensorPtr& input, const RotaryEmbeddingParams& params) { + validate_rope_input_and_params(input, params); + + auto out_tensor = ttnn::experimental::rotary_embedding_llama( + input->get_value(), params.cos_cache, params.sin_cache, params.trans_mat); + auto out = autograd::create_tensor(out_tensor); + + // In the backward pass we rotate by -θ, so we need negated cos and sin + // caches. Note: we can just reuse trans_mat here since the data movement + // should be the same on the backward pass (we use the same trick to speed + // up the matmul, and the matrix used is specified by the cos/sin caches.) 
+ autograd::GradFunction grad_fn = [input, params, out]() { + auto dL_dout = out->get_grad(); + + auto dL_dinput = ttnn::experimental::rotary_embedding_llama( + dL_dout, params.neg_cos_cache, params.neg_sin_cache, params.trans_mat); + input->add_grad(dL_dinput); + }; + + auto links = autograd::get_links(input); + out->set_node(autograd::ctx().add_backward_node(std::move(grad_fn), links)); + + return out; +} + +std::pair gen_freqs(uint32_t head_dim, uint32_t sequence_length, float theta = 10000.0F) { + int d = head_dim; + // compute freqs: 1.0 / (theta ** (2 * (i-1) / head_dim)) for i in [1, head_dim/2] + xt::xarray expt_data = xt::arange(0, d) / 2; + xt::xarray expt_xt = xt::cast(expt_data); + + expt_xt *= 2.0F / static_cast(head_dim); + xt::xarray theta_pow = xt::pow(theta, expt_xt); + + auto freqs = xt::ones_like(theta_pow) / theta_pow; + + xt::xarray seq_pos = xt::arange(sequence_length); + xt::xarray seq_pos_repeated_to_head = xt::repeat(seq_pos, head_dim, seq_pos.dimension() - 1U); + xt::xarray scales = seq_pos_repeated_to_head.reshape({sequence_length, static_cast(head_dim)}); + + xt::xarray scaled_freqs = scales * freqs; + + // take the scaled freqs mod 2π to satisfy ttnn inputs constraints for sin/cos + auto pi = static_cast(std::numbers::pi); + scaled_freqs = xt::fmod(scaled_freqs, 2.0F * pi); + scaled_freqs = scaled_freqs.reshape({1, 1, sequence_length, head_dim}); + + xt::xarray sin_freqs = xt::sin(scaled_freqs); + xt::xarray cos_freqs = xt::cos(scaled_freqs); + + auto* device = &autograd::ctx().get_device(); + return {core::from_xtensor(sin_freqs, device), core::from_xtensor(cos_freqs, device)}; +} + +ttnn::Tensor gen_trans_mat(int head_dim) { + xt::xarray trans_mat = xt::zeros({1, 1, head_dim, head_dim}); + for (int i = 0; i < head_dim; i += 2) { + trans_mat(0, 0, i, i + 1) = 1.0F; + } + for (int j = 1; j < head_dim; j += 2) { + trans_mat(0, 0, j, j - 1) = -1.0F; + } + + auto device = &autograd::ctx().get_device(); + return 
core::from_xtensor(trans_mat, device); +} + +RotaryEmbeddingParams build_rope_params(uint32_t sequence_length, uint32_t head_dim, float theta) { + if (head_dim % 32U != 0U) { + throw std::invalid_argument("RoPE head_dim must be divisible by 32"); + } + if (head_dim > 256U) { + throw std::invalid_argument("RoPE head_dim must be less than or equal to 256"); + } + if (head_dim <= 0U) { + throw std::invalid_argument("RoPE head_dim must be greater than 0"); + } + auto [sin_freqs, cos_freqs] = gen_freqs(head_dim, sequence_length, theta); + auto trans_mat = gen_trans_mat(head_dim); + + return { + .cos_cache = cos_freqs, + .sin_cache = sin_freqs, + .neg_cos_cache = cos_freqs, // cos(θ) = cos(-θ): symmetry over x-axis + .neg_sin_cache = ttnn::neg(sin_freqs), // sin(-θ) = -sin(θ) + .trans_mat = trans_mat, + + .sequence_length = sequence_length, + .head_dim = head_dim, + }; +} + +} // namespace ttml::ops diff --git a/tt-train/sources/ttml/ops/rope_op.hpp b/tt-train/sources/ttml/ops/rope_op.hpp new file mode 100644 index 00000000000..82d039d6a9e --- /dev/null +++ b/tt-train/sources/ttml/ops/rope_op.hpp @@ -0,0 +1,29 @@ +// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "autograd/tensor.hpp" + +namespace ttml::ops { + +struct RotaryEmbeddingParams { + ttnn::Tensor cos_cache; + ttnn::Tensor sin_cache; + ttnn::Tensor neg_cos_cache; + ttnn::Tensor neg_sin_cache; + ttnn::Tensor trans_mat; + + uint32_t sequence_length = 0; + uint32_t head_dim = 0; +}; + +autograd::TensorPtr rope(const autograd::TensorPtr& input, const RotaryEmbeddingParams& rope_params); + +RotaryEmbeddingParams build_rope_params(uint32_t sequence_length, uint32_t head_dim, float theta = 10000.0F); +// Throws an exception if the input is bad, parameters are bad, or the two are +// incompatible with one another. 
+void validate_rope_input_and_params(const autograd::TensorPtr& input, const RotaryEmbeddingParams& rope_params); + +} // namespace ttml::ops diff --git a/tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.cpp b/tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.cpp index cf6aa1a3473..de37267292a 100644 --- a/tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.cpp +++ b/tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.cpp @@ -76,4 +76,12 @@ tt::tt_metal::Tensor sum_ttnn(const tt::tt_metal::Tensor& t, int dim, bool keep_ return ttnn::sum(t, dim, keep_dim, std::nullopt, core::ComputeKernelConfig::precise()); } +tt::tt_metal::Tensor to_l1_interleaved(const tt::tt_metal::Tensor& t) { + return ttnn::to_memory_config(t, ttnn::L1_MEMORY_CONFIG); +} + +tt::tt_metal::Tensor to_dram_interleaved(const tt::tt_metal::Tensor& t) { + return ttnn::to_memory_config(t, ttnn::DRAM_MEMORY_CONFIG); +} + } // namespace ttml::ttnn_fixed diff --git a/tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.hpp b/tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.hpp index c8a62d981bc..77b25884d53 100644 --- a/tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.hpp +++ b/tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.hpp @@ -20,4 +20,8 @@ tt::tt_metal::Tensor mean_ttnn(const tt::tt_metal::Tensor& t, int dim, bool keep tt::tt_metal::Tensor sum_moreh(const tt::tt_metal::Tensor& t, int dim, bool keep_dim); tt::tt_metal::Tensor sum_ttnn(const tt::tt_metal::Tensor& t, int dim, bool keep_dim); + +tt::tt_metal::Tensor to_l1_interleaved(const tt::tt_metal::Tensor& t); +tt::tt_metal::Tensor to_dram_interleaved(const tt::tt_metal::Tensor& t); + } // namespace ttml::ttnn_fixed diff --git a/tt-train/tests/ops/rope_test.cpp b/tt-train/tests/ops/rope_test.cpp new file mode 100644 index 00000000000..387e74ad6a7 --- /dev/null +++ b/tt-train/tests/ops/rope_test.cpp @@ -0,0 +1,531 @@ +// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 +#include + +#include + +#include 
"autograd/tensor.hpp" +#include "core/tt_tensor_utils.hpp" +#include "core/xtensor_utils.hpp" +#include "modules/positional_embeddings.hpp" +#include "modules/rotary_embedding.hpp" +#include "ops/losses.hpp" + +namespace ttml::modules::tests { + +class RoPETest : public ::testing::Test { +protected: + void SetUp() override { + ttml::autograd::ctx().open_device(); + } + + void TearDown() override { + ttml::autograd::ctx().close_device(); + } +}; + +TEST_F(RoPETest, GeneratedParamsOk) { + xt::xarray expected_cos = { + {{{1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, + 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, + 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F}, + {0.54030F, 0.54030F, 0.84601F, 0.84601F, 0.95042F, 0.95042F, 0.98423F, 0.98423F, 0.99500F, 0.99500F, 0.99842F, + 0.99842F, 0.99950F, 0.99950F, 0.99984F, 0.99984F, 0.99995F, 0.99995F, 0.99998F, 0.99998F, 0.99999F, 0.99999F, + 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F}, + {-0.41615F, -0.41615F, 0.43146F, 0.43146F, 0.80658F, 0.80658F, 0.93742F, 0.93742F, + 0.98007F, 0.98007F, 0.99368F, 0.99368F, 0.99800F, 0.99800F, 0.99937F, 0.99937F, + 0.99980F, 0.99980F, 0.99994F, 0.99994F, 0.99998F, 0.99998F, 0.99999F, 0.99999F, + 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F}, + {-0.98999F, -0.98999F, -0.11597F, -0.11597F, 0.58275F, 0.58275F, 0.86104F, 0.86104F, + 0.95534F, 0.95534F, 0.98580F, 0.98580F, 0.99550F, 0.99550F, 0.99858F, 0.99858F, + 0.99955F, 0.99955F, 0.99986F, 0.99986F, 0.99995F, 0.99995F, 0.99999F, 0.99999F, + 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F}, + {-0.65364F, -0.65364F, -0.62768F, -0.62768F, 0.30114F, 0.30114F, 0.75751F, 0.75751F, + 0.92106F, 0.92106F, 0.97481F, 0.97481F, 0.99201F, 0.99201F, 
0.99747F, 0.99747F, + 0.99920F, 0.99920F, 0.99975F, 0.99975F, 0.99992F, 0.99992F, 0.99997F, 0.99997F, + 0.99999F, 0.99999F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F}, + {0.28366F, 0.28366F, -0.94608F, -0.94608F, -0.01034F, -0.01034F, 0.63008F, 0.63008F, + 0.87758F, 0.87758F, 0.96073F, 0.96073F, 0.98753F, 0.98753F, 0.99605F, 0.99605F, + 0.99875F, 0.99875F, 0.99960F, 0.99960F, 0.99988F, 0.99988F, 0.99996F, 0.99996F, + 0.99999F, 0.99999F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F}, + {0.96017F, 0.96017F, -0.97310F, -0.97310F, -0.32080F, -0.32080F, 0.48278F, 0.48278F, + 0.82534F, 0.82534F, 0.94362F, 0.94362F, 0.98205F, 0.98205F, 0.99431F, 0.99431F, + 0.99820F, 0.99820F, 0.99943F, 0.99943F, 0.99982F, 0.99982F, 0.99994F, 0.99994F, + 0.99998F, 0.99998F, 0.99999F, 0.99999F, 1.00000F, 1.00000F, 1.00000F, 1.00000F}, + {0.75390F, 0.75390F, -0.70043F, -0.70043F, -0.59944F, -0.59944F, 0.32026F, 0.32026F, + 0.76484F, 0.76484F, 0.92352F, 0.92352F, 0.97560F, 0.97560F, 0.99226F, 0.99226F, + 0.99755F, 0.99755F, 0.99923F, 0.99923F, 0.99976F, 0.99976F, 0.99992F, 0.99992F, + 0.99998F, 0.99998F, 0.99999F, 0.99999F, 1.00000F, 1.00000F, 1.00000F, 1.00000F}, + {-0.14550F, -0.14550F, -0.21204F, -0.21204F, -0.81863F, -0.81863F, 0.14763F, 0.14763F, + 0.69671F, 0.69671F, 0.90050F, 0.90050F, 0.96817F, 0.96817F, 0.98990F, 0.98990F, + 0.99680F, 0.99680F, 0.99899F, 0.99899F, 0.99968F, 0.99968F, 0.99990F, 0.99990F, + 0.99997F, 0.99997F, 0.99999F, 0.99999F, 1.00000F, 1.00000F, 1.00000F, 1.00000F}, + {-0.91113F, -0.91113F, 0.34166F, 0.34166F, -0.95664F, -0.95664F, -0.02965F, -0.02965F, + 0.62161F, 0.62161F, 0.87464F, 0.87464F, 0.95977F, 0.95977F, 0.98722F, 0.98722F, + 0.99595F, 0.99595F, 0.99872F, 0.99872F, 0.99960F, 0.99960F, 0.99987F, 0.99987F, + 0.99996F, 0.99996F, 0.99999F, 0.99999F, 1.00000F, 1.00000F, 1.00000F, 1.00000F}, + {-0.83907F, -0.83907F, 0.79013F, 0.79013F, -0.99979F, -0.99979F, -0.20600F, -0.20600F, + 0.54030F, 0.54030F, 0.84601F, 0.84601F, 0.95042F, 
0.95042F, 0.98423F, 0.98423F, + 0.99500F, 0.99500F, 0.99842F, 0.99842F, 0.99950F, 0.99950F, 0.99984F, 0.99984F, + 0.99995F, 0.99995F, 0.99998F, 0.99998F, 0.99999F, 0.99999F, 1.00000F, 1.00000F}, + {0.00443F, 0.00443F, 0.99526F, 0.99526F, -0.94378F, -0.94378F, -0.37585F, -0.37585F, + 0.45360F, 0.45360F, 0.81471F, 0.81471F, 0.94011F, 0.94011F, 0.98093F, 0.98093F, + 0.99396F, 0.99396F, 0.99809F, 0.99809F, 0.99940F, 0.99940F, 0.99981F, 0.99981F, + 0.99994F, 0.99994F, 0.99998F, 0.99998F, 0.99999F, 0.99999F, 1.00000F, 1.00000F}, + {0.84385F, 0.84385F, 0.89386F, 0.89386F, -0.79418F, -0.79418F, -0.53384F, -0.53384F, + 0.36236F, 0.36236F, 0.78083F, 0.78083F, 0.92886F, 0.92886F, 0.97732F, 0.97732F, + 0.99281F, 0.99281F, 0.99772F, 0.99772F, 0.99928F, 0.99928F, 0.99977F, 0.99977F, + 0.99993F, 0.99993F, 0.99998F, 0.99998F, 0.99999F, 0.99999F, 1.00000F, 1.00000F}, + {0.90745F, 0.90745F, 0.51717F, 0.51717F, -0.56582F, -0.56582F, -0.67500F, -0.67500F, + 0.26750F, 0.26750F, 0.74448F, 0.74448F, 0.91668F, 0.91668F, 0.97340F, 0.97340F, + 0.99156F, 0.99156F, 0.99733F, 0.99733F, 0.99916F, 0.99916F, 0.99973F, 0.99973F, + 0.99992F, 0.99992F, 0.99997F, 0.99997F, 0.99999F, 0.99999F, 1.00000F, 1.00000F}, + {0.13674F, 0.13674F, -0.01880F, -0.01880F, -0.28135F, -0.28135F, -0.79487F, -0.79487F, + 0.16997F, 0.16997F, 0.70578F, 0.70578F, 0.90359F, 0.90359F, 0.96917F, 0.96917F, + 0.99022F, 0.99022F, 0.99690F, 0.99690F, 0.99902F, 0.99902F, 0.99969F, 0.99969F, + 0.99990F, 0.99990F, 0.99997F, 0.99997F, 0.99999F, 0.99999F, 1.00000F, 1.00000F}, + {-0.75969F, -0.75969F, -0.54898F, -0.54898F, 0.03102F, 0.03102F, -0.88967F, -0.88967F, + 0.07074F, 0.07074F, 0.66484F, 0.66484F, 0.88959F, 0.88959F, 0.96463F, 0.96463F, + 0.98877F, 0.98877F, 0.99644F, 0.99644F, 0.99888F, 0.99888F, 0.99964F, 0.99964F, + 0.99989F, 0.99989F, 0.99996F, 0.99996F, 0.99999F, 0.99999F, 1.00000F, 1.00000F}, + {-0.95766F, -0.95766F, -0.91008F, -0.91008F, 0.34032F, 0.34032F, -0.95641F, -0.95641F, + -0.02920F, -0.02920F, 0.62181F, 
0.62181F, 0.87471F, 0.87471F, 0.95980F, 0.95980F, + 0.98723F, 0.98723F, 0.99595F, 0.99595F, 0.99872F, 0.99872F, 0.99960F, 0.99960F, + 0.99987F, 0.99987F, 0.99996F, 0.99996F, 0.99999F, 0.99999F, 1.00000F, 1.00000F}, + {-0.27516F, -0.27516F, -0.99090F, -0.99090F, 0.61586F, 0.61586F, -0.99298F, -0.99298F, + -0.12884F, -0.12884F, 0.57681F, 0.57681F, 0.85895F, 0.85895F, 0.95465F, 0.95465F, + 0.98558F, 0.98558F, 0.99543F, 0.99543F, 0.99856F, 0.99856F, 0.99954F, 0.99954F, + 0.99986F, 0.99986F, 0.99995F, 0.99995F, 0.99999F, 0.99999F, 1.00000F, 1.00000F}, + {0.66032F, 0.66032F, -0.76654F, -0.76654F, 0.83034F, 0.83034F, -0.99824F, -0.99824F, + -0.22720F, -0.22720F, 0.52998F, 0.52998F, 0.84233F, 0.84233F, 0.94921F, 0.94921F, + 0.98384F, 0.98384F, 0.99488F, 0.99488F, 0.99838F, 0.99838F, 0.99949F, 0.99949F, + 0.99984F, 0.99984F, 0.99995F, 0.99995F, 0.99998F, 0.99998F, 0.99999F, 0.99999F}, + {0.98870F, 0.98870F, -0.30610F, -0.30610F, 0.96246F, 0.96246F, -0.97201F, -0.97201F, + -0.32329F, -0.32329F, 0.48148F, 0.48148F, 0.82487F, 0.82487F, 0.94346F, 0.94346F, + 0.98200F, 0.98200F, 0.99430F, 0.99430F, 0.99820F, 0.99820F, 0.99943F, 0.99943F, + 0.99982F, 0.99982F, 0.99994F, 0.99994F, 0.99998F, 0.99998F, 0.99999F, 0.99999F}, + {0.40808F, 0.40808F, 0.24862F, 0.24862F, 0.99914F, 0.99914F, -0.91513F, -0.91513F, + -0.41615F, -0.41615F, 0.43146F, 0.43146F, 0.80658F, 0.80658F, 0.93742F, 0.93742F, + 0.98007F, 0.98007F, 0.99368F, 0.99368F, 0.99800F, 0.99800F, 0.99937F, 0.99937F, + 0.99980F, 0.99980F, 0.99994F, 0.99994F, 0.99998F, 0.99998F, 0.99999F, 0.99999F}, + {-0.54773F, -0.54773F, 0.72676F, 0.72676F, 0.93674F, 0.93674F, -0.82938F, -0.82938F, + -0.50485F, -0.50485F, 0.38008F, 0.38008F, 0.78749F, 0.78749F, 0.93108F, 0.93108F, + 0.97803F, 0.97803F, 0.99304F, 0.99304F, 0.99780F, 0.99780F, 0.99930F, 0.99930F, + 0.99978F, 0.99978F, 0.99993F, 0.99993F, 0.99998F, 0.99998F, 0.99999F, 0.99999F}, + {-0.99996F, -0.99996F, 0.98107F, 0.98107F, 0.78144F, 0.78144F, -0.71748F, -0.71748F, + -0.58850F, 
-0.58850F, 0.32749F, 0.32749F, 0.76760F, 0.76760F, 0.92444F, 0.92444F, + 0.97590F, 0.97590F, 0.99236F, 0.99236F, 0.99758F, 0.99758F, 0.99923F, 0.99923F, + 0.99976F, 0.99976F, 0.99992F, 0.99992F, 0.99998F, 0.99998F, 0.99999F, 0.99999F}, + {-0.53283F, -0.53283F, 0.93324F, 0.93324F, 0.54865F, 0.54865F, -0.58294F, -0.58294F, + -0.66628F, -0.66628F, 0.27387F, 0.27387F, 0.74696F, 0.74696F, 0.91752F, 0.91752F, + 0.97367F, 0.97367F, 0.99165F, 0.99165F, 0.99736F, 0.99736F, 0.99916F, 0.99916F, + 0.99974F, 0.99974F, 0.99992F, 0.99992F, 0.99997F, 0.99997F, 0.99999F, 0.99999F}, + {0.42418F, 0.42418F, 0.59798F, 0.59798F, 0.26144F, 0.26144F, -0.43002F, -0.43002F, + -0.73739F, -0.73739F, 0.21938F, 0.21938F, 0.72556F, 0.72556F, 0.91030F, 0.91030F, + 0.97134F, 0.97134F, 0.99091F, 0.99091F, 0.99712F, 0.99712F, 0.99909F, 0.99909F, + 0.99971F, 0.99971F, 0.99991F, 0.99991F, 0.99997F, 0.99997F, 0.99999F, 0.99999F}, + {0.99120F, 0.99120F, 0.07855F, 0.07855F, -0.05169F, -0.05169F, -0.26354F, -0.26354F, + -0.80114F, -0.80114F, 0.16420F, 0.16420F, 0.70344F, 0.70344F, 0.90280F, 0.90280F, + 0.96891F, 0.96891F, 0.99013F, 0.99013F, 0.99688F, 0.99688F, 0.99901F, 0.99901F, + 0.99969F, 0.99969F, 0.99990F, 0.99990F, 0.99997F, 0.99997F, 0.99999F, 0.99999F}, + {0.64692F, 0.64692F, -0.46506F, -0.46506F, -0.35969F, -0.35969F, -0.08875F, -0.08875F, + -0.85689F, -0.85689F, 0.10849F, 0.10849F, 0.68062F, 0.68062F, 0.89501F, 0.89501F, + 0.96639F, 0.96639F, 0.98933F, 0.98933F, 0.99662F, 0.99662F, 0.99893F, 0.99893F, + 0.99966F, 0.99966F, 0.99989F, 0.99989F, 0.99997F, 0.99997F, 0.99999F, 0.99999F}, + {-0.29214F, -0.29214F, -0.86545F, -0.86545F, -0.63203F, -0.63203F, 0.08885F, 0.08885F, + -0.90407F, -0.90407F, 0.05245F, 0.05245F, 0.65711F, 0.65711F, 0.88693F, 0.88693F, + 0.96377F, 0.96377F, 0.98850F, 0.98850F, 0.99636F, 0.99636F, 0.99885F, 0.99885F, + 0.99964F, 0.99964F, 0.99988F, 0.99988F, 0.99996F, 0.99996F, 0.99999F, 0.99999F}, + {-0.96261F, -0.96261F, -0.99929F, -0.99929F, -0.84168F, -0.84168F, 0.26364F, 
0.26364F, + -0.94222F, -0.94222F, -0.00376F, -0.00376F, 0.63295F, 0.63295F, 0.87858F, 0.87858F, + 0.96106F, 0.96106F, 0.98763F, 0.98763F, 0.99608F, 0.99608F, 0.99876F, 0.99876F, + 0.99961F, 0.99961F, 0.99988F, 0.99988F, 0.99996F, 0.99996F, 0.99999F, 0.99999F}, + {-0.74806F, -0.74806F, -0.82537F, -0.82537F, -0.96787F, -0.96787F, 0.43012F, 0.43012F, + -0.97096F, -0.97096F, -0.05996F, -0.05996F, 0.60816F, 0.60816F, 0.86995F, 0.86995F, + 0.95824F, 0.95824F, 0.98673F, 0.98673F, 0.99580F, 0.99580F, 0.99867F, 0.99867F, + 0.99958F, 0.99958F, 0.99987F, 0.99987F, 0.99996F, 0.99996F, 0.99999F, 0.99999F}, + {0.15425F, 0.15425F, -0.39725F, -0.39725F, -0.99808F, -0.99808F, 0.58303F, 0.58303F, + -0.98999F, -0.98999F, -0.11597F, -0.11597F, 0.58275F, 0.58275F, 0.86104F, 0.86104F, + 0.95534F, 0.95534F, 0.98580F, 0.98580F, 0.99550F, 0.99550F, 0.99858F, 0.99858F, + 0.99955F, 0.99955F, 0.99986F, 0.99986F, 0.99995F, 0.99995F, 0.99999F, 0.99999F}, + {0.91474F, 0.91474F, 0.15322F, 0.15322F, -0.92930F, -0.92930F, 0.71755F, 0.71755F, + -0.99914F, -0.99914F, -0.17161F, -0.17161F, 0.55677F, 0.55677F, 0.85186F, 0.85186F, + 0.95233F, 0.95233F, 0.98484F, 0.98484F, 0.99520F, 0.99520F, 0.99848F, 0.99848F, + 0.99952F, 0.99952F, 0.99985F, 0.99985F, 0.99995F, 0.99995F, 0.99998F, 0.99998F}}}}; + xt::xarray expected_sin = { + {{{0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + {0.84147F, 0.84147F, 0.53317F, 0.53317F, 0.31098F, 0.31098F, 0.17689F, 0.17689F, 0.09983F, 0.09983F, 0.05620F, + 0.05620F, 0.03162F, 0.03162F, 0.01778F, 0.01778F, 0.01000F, 0.01000F, 0.00562F, 0.00562F, 0.00316F, 0.00316F, + 0.00178F, 0.00178F, 0.00100F, 0.00100F, 0.00056F, 0.00056F, 0.00032F, 0.00032F, 0.00018F, 0.00018F}, + {0.90930F, 0.90930F, 0.90213F, 
0.90213F, 0.59113F, 0.59113F, 0.34821F, 0.34821F, 0.19867F, 0.19867F, 0.11223F, + 0.11223F, 0.06320F, 0.06320F, 0.03556F, 0.03556F, 0.02000F, 0.02000F, 0.01125F, 0.01125F, 0.00632F, 0.00632F, + 0.00356F, 0.00356F, 0.00200F, 0.00200F, 0.00112F, 0.00112F, 0.00063F, 0.00063F, 0.00036F, 0.00036F}, + {0.14112F, 0.14112F, 0.99325F, 0.99325F, 0.81265F, 0.81265F, 0.50854F, 0.50854F, 0.29552F, 0.29552F, 0.16790F, + 0.16790F, 0.09473F, 0.09473F, 0.05332F, 0.05332F, 0.03000F, 0.03000F, 0.01687F, 0.01687F, 0.00949F, 0.00949F, + 0.00533F, 0.00533F, 0.00300F, 0.00300F, 0.00169F, 0.00169F, 0.00095F, 0.00095F, 0.00053F, 0.00053F}, + {-0.75680F, -0.75680F, 0.77847F, 0.77847F, 0.95358F, 0.95358F, 0.65283F, 0.65283F, + 0.38942F, 0.38942F, 0.22304F, 0.22304F, 0.12615F, 0.12615F, 0.07107F, 0.07107F, + 0.03999F, 0.03999F, 0.02249F, 0.02249F, 0.01265F, 0.01265F, 0.00711F, 0.00711F, + 0.00400F, 0.00400F, 0.00225F, 0.00225F, 0.00126F, 0.00126F, 0.00071F, 0.00071F}, + {-0.95892F, -0.95892F, 0.32394F, 0.32394F, 0.99995F, 0.99995F, 0.77653F, 0.77653F, + 0.47943F, 0.47943F, 0.27748F, 0.27748F, 0.15746F, 0.15746F, 0.08880F, 0.08880F, + 0.04998F, 0.04998F, 0.02811F, 0.02811F, 0.01581F, 0.01581F, 0.00889F, 0.00889F, + 0.00500F, 0.00500F, 0.00281F, 0.00281F, 0.00158F, 0.00158F, 0.00089F, 0.00089F}, + {-0.27942F, -0.27942F, -0.23037F, -0.23037F, 0.94715F, 0.94715F, 0.87574F, 0.87574F, + 0.56464F, 0.56464F, 0.33104F, 0.33104F, 0.18860F, 0.18860F, 0.10649F, 0.10649F, + 0.05996F, 0.05996F, 0.03373F, 0.03373F, 0.01897F, 0.01897F, 0.01067F, 0.01067F, + 0.00600F, 0.00600F, 0.00337F, 0.00337F, 0.00190F, 0.00190F, 0.00107F, 0.00107F}, + {0.65699F, 0.65699F, -0.71372F, -0.71372F, 0.80042F, 0.80042F, 0.94733F, 0.94733F, + 0.64422F, 0.64422F, 0.38355F, 0.38355F, 0.21956F, 0.21956F, 0.12416F, 0.12416F, + 0.06994F, 0.06994F, 0.03935F, 0.03935F, 0.02213F, 0.02213F, 0.01245F, 0.01245F, + 0.00700F, 0.00700F, 0.00394F, 0.00394F, 0.00221F, 0.00221F, 0.00124F, 0.00124F}, + {0.98936F, 0.98936F, -0.97726F, -0.97726F, 
0.57432F, 0.57432F, 0.98904F, 0.98904F, + 0.71736F, 0.71736F, 0.43485F, 0.43485F, 0.25029F, 0.25029F, 0.14178F, 0.14178F, + 0.07991F, 0.07991F, 0.04497F, 0.04497F, 0.02530F, 0.02530F, 0.01423F, 0.01423F, + 0.00800F, 0.00800F, 0.00450F, 0.00450F, 0.00253F, 0.00253F, 0.00142F, 0.00142F}, + {0.41212F, 0.41212F, -0.93982F, -0.93982F, 0.29126F, 0.29126F, 0.99956F, 0.99956F, + 0.78333F, 0.78333F, 0.48478F, 0.48478F, 0.28078F, 0.28078F, 0.15936F, 0.15936F, + 0.08988F, 0.08988F, 0.05059F, 0.05059F, 0.02846F, 0.02846F, 0.01600F, 0.01600F, + 0.00900F, 0.00900F, 0.00506F, 0.00506F, 0.00285F, 0.00285F, 0.00160F, 0.00160F}, + {-0.54402F, -0.54402F, -0.61294F, -0.61294F, -0.02068F, -0.02068F, 0.97855F, 0.97855F, + 0.84147F, 0.84147F, 0.53317F, 0.53317F, 0.31098F, 0.31098F, 0.17689F, 0.17689F, + 0.09983F, 0.09983F, 0.05620F, 0.05620F, 0.03162F, 0.03162F, 0.01778F, 0.01778F, + 0.01000F, 0.01000F, 0.00562F, 0.00562F, 0.00316F, 0.00316F, 0.00178F, 0.00178F}, + {-0.99999F, -0.99999F, -0.09728F, -0.09728F, -0.33057F, -0.33057F, 0.92668F, 0.92668F, + 0.89121F, 0.89121F, 0.57988F, 0.57988F, 0.34088F, 0.34088F, 0.19437F, 0.19437F, + 0.10978F, 0.10978F, 0.06182F, 0.06182F, 0.03478F, 0.03478F, 0.01956F, 0.01956F, + 0.01100F, 0.01100F, 0.00619F, 0.00619F, 0.00348F, 0.00348F, 0.00196F, 0.00196F}, + {-0.53657F, -0.53657F, 0.44834F, 0.44834F, -0.60768F, -0.60768F, 0.84558F, 0.84558F, + 0.93204F, 0.93204F, 0.62475F, 0.62475F, 0.37043F, 0.37043F, 0.21178F, 0.21178F, + 0.11971F, 0.11971F, 0.06743F, 0.06743F, 0.03794F, 0.03794F, 0.02134F, 0.02134F, + 0.01200F, 0.01200F, 0.00675F, 0.00675F, 0.00379F, 0.00379F, 0.00213F, 0.00213F}, + {0.42017F, 0.42017F, 0.85588F, 0.85588F, -0.82453F, -0.82453F, 0.73782F, 0.73782F, + 0.96356F, 0.96356F, 0.66765F, 0.66765F, 0.39961F, 0.39961F, 0.22912F, 0.22912F, + 0.12963F, 0.12963F, 0.07304F, 0.07304F, 0.04110F, 0.04110F, 0.02312F, 0.02312F, + 0.01300F, 0.01300F, 0.00731F, 0.00731F, 0.00411F, 0.00411F, 0.00231F, 0.00231F}, + {0.99061F, 0.99061F, 0.99982F, 0.99982F, 
-0.95961F, -0.95961F, 0.60678F, 0.60678F, + 0.98545F, 0.98545F, 0.70843F, 0.70843F, 0.42840F, 0.42840F, 0.24640F, 0.24640F, + 0.13954F, 0.13954F, 0.07865F, 0.07865F, 0.04426F, 0.04426F, 0.02489F, 0.02489F, + 0.01400F, 0.01400F, 0.00787F, 0.00787F, 0.00443F, 0.00443F, 0.00249F, 0.00249F}, + {0.65029F, 0.65029F, 0.83584F, 0.83584F, -0.99952F, -0.99952F, 0.45660F, 0.45660F, + 0.99749F, 0.99749F, 0.74698F, 0.74698F, 0.45675F, 0.45675F, 0.26359F, 0.26359F, + 0.14944F, 0.14944F, 0.08425F, 0.08425F, 0.04742F, 0.04742F, 0.02667F, 0.02667F, + 0.01500F, 0.01500F, 0.00844F, 0.00844F, 0.00474F, 0.00474F, 0.00267F, 0.00267F}, + {-0.28790F, -0.28790F, 0.41443F, 0.41443F, -0.94031F, -0.94031F, 0.29203F, 0.29203F, + 0.99957F, 0.99957F, 0.78317F, 0.78317F, 0.48465F, 0.48465F, 0.28070F, 0.28070F, + 0.15932F, 0.15932F, 0.08985F, 0.08985F, 0.05057F, 0.05057F, 0.02845F, 0.02845F, + 0.01600F, 0.01600F, 0.00900F, 0.00900F, 0.00506F, 0.00506F, 0.00285F, 0.00285F}, + {-0.96140F, -0.96140F, -0.13462F, -0.13462F, -0.78785F, -0.78785F, 0.11824F, 0.11824F, + 0.99166F, 0.99166F, 0.81688F, 0.81688F, 0.51207F, 0.51207F, 0.29772F, 0.29772F, + 0.16918F, 0.16918F, 0.09545F, 0.09545F, 0.05373F, 0.05373F, 0.03023F, 0.03023F, + 0.01700F, 0.01700F, 0.00956F, 0.00956F, 0.00538F, 0.00538F, 0.00302F, 0.00302F}, + {-0.75099F, -0.75099F, -0.64220F, -0.64220F, -0.55726F, -0.55726F, -0.05928F, -0.05928F, + 0.97385F, 0.97385F, 0.84801F, 0.84801F, 0.53897F, 0.53897F, 0.31465F, 0.31465F, + 0.17903F, 0.17903F, 0.10105F, 0.10105F, 0.05689F, 0.05689F, 0.03200F, 0.03200F, + 0.01800F, 0.01800F, 0.01012F, 0.01012F, 0.00569F, 0.00569F, 0.00320F, 0.00320F}, + {0.14988F, 0.14988F, -0.95200F, -0.95200F, -0.27141F, -0.27141F, -0.23492F, -0.23492F, + 0.94630F, 0.94630F, 0.87645F, 0.87645F, 0.56533F, 0.56533F, 0.33148F, 0.33148F, + 0.18886F, 0.18886F, 0.10664F, 0.10664F, 0.06005F, 0.06005F, 0.03378F, 0.03378F, + 0.01900F, 0.01900F, 0.01068F, 0.01068F, 0.00601F, 0.00601F, 0.00338F, 0.00338F}, + {0.91295F, 0.91295F, -0.96860F, 
-0.96860F, 0.04136F, 0.04136F, -0.40316F, -0.40316F, + 0.90930F, 0.90930F, 0.90213F, 0.90213F, 0.59113F, 0.59113F, 0.34821F, 0.34821F, + 0.19867F, 0.19867F, 0.11223F, 0.11223F, 0.06320F, 0.06320F, 0.03556F, 0.03556F, + 0.02000F, 0.02000F, 0.01125F, 0.01125F, 0.00632F, 0.00632F, 0.00356F, 0.00356F}, + {0.83666F, 0.83666F, -0.68689F, -0.68689F, 0.35002F, 0.35002F, -0.55868F, -0.55868F, + 0.86321F, 0.86321F, 0.92495F, 0.92495F, 0.61633F, 0.61633F, 0.36482F, 0.36482F, + 0.20846F, 0.20846F, 0.11782F, 0.11782F, 0.06636F, 0.06636F, 0.03734F, 0.03734F, + 0.02100F, 0.02100F, 0.01181F, 0.01181F, 0.00664F, 0.00664F, 0.00373F, 0.00373F}, + {-0.00885F, -0.00885F, -0.19363F, -0.19363F, 0.62398F, 0.62398F, -0.69658F, -0.69658F, + 0.80850F, 0.80850F, 0.94485F, 0.94485F, 0.64092F, 0.64092F, 0.38132F, 0.38132F, + 0.21823F, 0.21823F, 0.12340F, 0.12340F, 0.06951F, 0.06951F, 0.03911F, 0.03911F, + 0.02200F, 0.02200F, 0.01237F, 0.01237F, 0.00696F, 0.00696F, 0.00391F, 0.00391F}, + {-0.84622F, -0.84622F, 0.35926F, 0.35926F, 0.83606F, 0.83606F, -0.81251F, -0.81251F, + 0.74571F, 0.74571F, 0.96177F, 0.96177F, 0.66487F, 0.66487F, 0.39770F, 0.39770F, + 0.22798F, 0.22798F, 0.12898F, 0.12898F, 0.07267F, 0.07267F, 0.04089F, 0.04089F, + 0.02300F, 0.02300F, 0.01293F, 0.01293F, 0.00727F, 0.00727F, 0.00409F, 0.00409F}, + {-0.90558F, -0.90558F, 0.80151F, 0.80151F, 0.96522F, 0.96522F, -0.90282F, -0.90282F, + 0.67546F, 0.67546F, 0.97564F, 0.97564F, 0.68816F, 0.68816F, 0.41395F, 0.41395F, + 0.23770F, 0.23770F, 0.13455F, 0.13455F, 0.07582F, 0.07582F, 0.04267F, 0.04267F, + 0.02400F, 0.02400F, 0.01350F, 0.01350F, 0.00759F, 0.00759F, 0.00427F, 0.00427F}, + {-0.13235F, -0.13235F, 0.99691F, 0.99691F, 0.99866F, 0.99866F, -0.96465F, -0.96465F, + 0.59847F, 0.59847F, 0.98643F, 0.98643F, 0.71075F, 0.71075F, 0.43007F, 0.43007F, + 0.24740F, 0.24740F, 0.14012F, 0.14012F, 0.07897F, 0.07897F, 0.04444F, 0.04444F, + 0.02500F, 0.02500F, 0.01406F, 0.01406F, 0.00791F, 0.00791F, 0.00445F, 0.00445F}, + {0.76256F, 0.76256F, 
0.88528F, 0.88528F, 0.93307F, 0.93307F, -0.99605F, -0.99605F, + 0.51550F, 0.51550F, 0.99410F, 0.99410F, 0.73264F, 0.73264F, 0.44605F, 0.44605F, + 0.25708F, 0.25708F, 0.14569F, 0.14569F, 0.08213F, 0.08213F, 0.04622F, 0.04622F, + 0.02600F, 0.02600F, 0.01462F, 0.01462F, 0.00822F, 0.00822F, 0.00462F, 0.00462F}, + {0.95638F, 0.95638F, 0.50099F, 0.50099F, 0.77495F, 0.77495F, -0.99605F, -0.99605F, + 0.42738F, 0.42738F, 0.99862F, 0.99862F, 0.75379F, 0.75379F, 0.46190F, 0.46190F, + 0.26673F, 0.26673F, 0.15125F, 0.15125F, 0.08528F, 0.08528F, 0.04800F, 0.04800F, + 0.02700F, 0.02700F, 0.01518F, 0.01518F, 0.00854F, 0.00854F, 0.00480F, 0.00480F}, + {0.27091F, 0.27091F, -0.03759F, -0.03759F, 0.53997F, 0.53997F, -0.96462F, -0.96462F, + 0.33499F, 0.33499F, 0.99999F, 0.99999F, 0.77419F, 0.77419F, 0.47760F, 0.47760F, + 0.27636F, 0.27636F, 0.15681F, 0.15681F, 0.08843F, 0.08843F, 0.04977F, 0.04977F, + 0.02800F, 0.02800F, 0.01574F, 0.01574F, 0.00885F, 0.00885F, 0.00498F, 0.00498F}, + {-0.66363F, -0.66363F, -0.56459F, -0.56459F, 0.25145F, 0.25145F, -0.90277F, -0.90277F, + 0.23925F, 0.23925F, 0.99820F, 0.99820F, 0.79382F, 0.79382F, 0.49314F, 0.49314F, + 0.28595F, 0.28595F, 0.16236F, 0.16236F, 0.09158F, 0.09158F, 0.05155F, 0.05155F, + 0.02900F, 0.02900F, 0.01631F, 0.01631F, 0.00917F, 0.00917F, 0.00516F, 0.00516F}, + {-0.98803F, -0.98803F, -0.91771F, -0.91771F, -0.06201F, -0.06201F, -0.81245F, -0.81245F, + 0.14112F, 0.14112F, 0.99325F, 0.99325F, 0.81265F, 0.81265F, 0.50854F, 0.50854F, + 0.29552F, 0.29552F, 0.16790F, 0.16790F, 0.09473F, 0.09473F, 0.05332F, 0.05332F, + 0.03000F, 0.03000F, 0.01687F, 0.01687F, 0.00949F, 0.00949F, 0.00533F, 0.00533F}, + {-0.40404F, -0.40404F, -0.98819F, -0.98819F, -0.36933F, -0.36933F, -0.69651F, -0.69651F, + 0.04158F, 0.04158F, 0.98517F, 0.98517F, 0.83067F, 0.83067F, 0.52377F, 0.52377F, + 0.30506F, 0.30506F, 0.17344F, 0.17344F, 0.09787F, 0.09787F, 0.05510F, 0.05510F, + 0.03100F, 0.03100F, 0.01743F, 0.01743F, 0.00980F, 0.00980F, 0.00551F, 0.00551F}}}}; + 
xt::xarray expected_trans_mat = { + {{{0.00000F, 1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + {-1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + {0.00000F, 0.00000F, 0.00000F, 1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + {0.00000F, 0.00000F, -1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + {0.00000F, 0.00000F, 0.00000F, 0.00000F, -1.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 
0.00000F, 0.00000F}, + {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 1.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, -1.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 1.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + -1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, -1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + 
{0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, -1.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, -1.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + -1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + {0.00000F, 0.00000F, 
0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 1.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, -1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 1.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, -1.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, -1.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + {0.00000F, 0.00000F, 0.00000F, 0.00000F, 
0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + -1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, -1.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 1.00000F, 0.00000F, 0.00000F}, + {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, -1.00000F, 0.00000F, 0.00000F, 0.00000F}, + {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 
0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 1.00000F}, + {0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, -1.00000F, 0.00000F}}}}; + + auto rope_params = ops::build_rope_params( + /*sequence_length=*/32, + /*head_dim=*/32); + + EXPECT_TRUE(xt::allclose(expected_cos, core::to_xtensor(rope_params.cos_cache), /*rtol=*/0.01F, /*atol=*/0.03F)); + EXPECT_TRUE(xt::allclose(expected_sin, core::to_xtensor(rope_params.sin_cache), /*rtol=*/0.01F, /*atol=*/0.03F)); + EXPECT_TRUE(xt::allclose(expected_trans_mat, core::to_xtensor(rope_params.trans_mat))); +} + +TEST_F(RoPETest, ForwardTest) { + // Head dim must be a multiple of TILE_WIDTH + // Head dim must be <= 256 + // Forward check: feeds an all-ones {1, 2, 5, 32} tensor through RotaryEmbedding built with + // sequence_length=5 and head_dim=32, then compares the output with the precomputed values below. + + // Input query tensor + xt::xarray xq = xt::ones({1, 2, 5, 32}); + // Reference output; the two blocks along dim 1 hold identical rows. + xt::xarray expected_xq_out = { + {{{1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, + 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, + 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F}, + {-0.30078F, 1.38281F, 0.31641F, 1.38281F, 0.64062F, 1.25781F, 0.80859F, 1.16406F, + 0.89844F, 1.09375F, 0.94531F, 1.05469F, 0.96875F, 1.03125F, 0.98438F, 1.01562F, + 0.99219F, 1.00781F, 0.99609F, 1.00781F, 0.99609F, 1.00000F, 1.00000F, 1.00000F, + 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F}, + {-1.32812F, 0.49414F, -0.47070F, 1.33594F, 0.21484F, 1.39844F, 0.58984F, 1.28906F, + 0.78125F, 1.17969F, 
0.87891F, 1.10156F, 0.93359F, 1.06250F, 0.96484F, 1.03906F, + 0.98047F, 1.02344F, 0.98828F, 1.01562F, 0.99609F, 1.00781F, 0.99609F, 1.00781F, + 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F}, + {-1.13281F, -0.84766F, -1.10938F, 0.87500F, -0.23047F, 1.39844F, 0.35156F, 1.36719F, + 0.66406F, 1.25000F, 0.81641F, 1.15625F, 0.90234F, 1.09375F, 0.94531F, 1.05469F, + 0.96875F, 1.03125F, 0.98438F, 1.01562F, 0.99219F, 1.00781F, 0.99609F, 1.00781F, + 0.99609F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F}, + {0.10547F, -1.41406F, -1.40625F, 0.14844F, -0.65234F, 1.25781F, 0.10547F, 1.41406F, + 0.53516F, 1.31250F, 0.75391F, 1.20312F, 0.86719F, 1.11719F, 0.92578F, 1.07031F, + 0.96094F, 1.03906F, 0.97656F, 1.02344F, 0.98828F, 1.01562F, 0.99219F, 1.00781F, + 0.99609F, 1.00781F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F}}, + {{1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, + 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, + 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F}, + {-0.30078F, 1.38281F, 0.31641F, 1.38281F, 0.64062F, 1.25781F, 0.80859F, 1.16406F, + 0.89844F, 1.09375F, 0.94531F, 1.05469F, 0.96875F, 1.03125F, 0.98438F, 1.01562F, + 0.99219F, 1.00781F, 0.99609F, 1.00781F, 0.99609F, 1.00000F, 1.00000F, 1.00000F, + 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F}, + {-1.32812F, 0.49414F, -0.47070F, 1.33594F, 0.21484F, 1.39844F, 0.58984F, 1.28906F, + 0.78125F, 1.17969F, 0.87891F, 1.10156F, 0.93359F, 1.06250F, 0.96484F, 1.03906F, + 0.98047F, 1.02344F, 0.98828F, 1.01562F, 0.99609F, 1.00781F, 0.99609F, 1.00781F, + 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F}, + {-1.13281F, -0.84766F, -1.10938F, 0.87500F, -0.23047F, 1.39844F, 0.35156F, 1.36719F, + 0.66406F, 1.25000F, 
0.81641F, 1.15625F, 0.90234F, 1.09375F, 0.94531F, 1.05469F, + 0.96875F, 1.03125F, 0.98438F, 1.01562F, 0.99219F, 1.00781F, 0.99609F, 1.00781F, + 0.99609F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F}, + {0.10547F, -1.41406F, -1.40625F, 0.14844F, -0.65234F, 1.25781F, 0.10547F, 1.41406F, + 0.53516F, 1.31250F, 0.75391F, 1.20312F, 0.86719F, 1.11719F, 0.92578F, 1.07031F, + 0.96094F, 1.03906F, 0.97656F, 1.02344F, 0.98828F, 1.01562F, 0.99219F, 1.00781F, + 0.99609F, 1.00781F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F, 1.00000F}}}}; + + auto* device = &ttml::autograd::ctx().get_device(); + + // Call the RoPE function + auto rope_params = ops::build_rope_params( + /*sequence_length=*/5, + /*head_dim=*/32); + auto rope_mod = RotaryEmbedding(rope_params); + + auto xq_autograd_tensor = autograd::create_tensor(core::from_xtensor(xq, device)); + + auto actual_xq_out = rope_mod(xq_autograd_tensor); + + auto actual_xq_out_xt = core::to_xtensor(actual_xq_out->get_value()); + + // Check that outputs match the expected values + // (the third and fourth allclose arguments are rtol and atol, both 2e-1 here) + EXPECT_TRUE(xt::allclose(actual_xq_out_xt, expected_xq_out, 2e-1, 2e-1)); +} + +// Backward check: runs RotaryEmbedding on an all-ones input, takes the MSE loss against an all-ones +// target, back-propagates, and compares the input gradient with the precomputed values below. +TEST_F(RoPETest, BackwardTest) { + // Head dim must be a multiple of TILE_WIDTH + // Head dim must be <= 256 + // Input query tensor + xt::xarray xq = xt::ones({1, 2, 5, 32}); + // Reference gradient w.r.t. the input; the two blocks along dim 1 hold identical rows. + xt::xarray expected_grad = { + {{{0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + {-0.00238F, 0.00812F, -0.00238F, 0.00433F, -0.00163F, 0.00223F, -0.00100F, 0.00123F, + -0.00060F, 0.00065F, -0.00035F, 0.00036F, -0.00019F, 0.00020F, -0.00012F, 0.00010F, + -0.00007F, 0.00005F, -0.00002F, 0.00005F, -0.00002F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 
0.00000F, 0.00000F}, + {0.00319F, 0.01459F, -0.00208F, 0.00922F, -0.00250F, 0.00488F, -0.00177F, 0.00259F, + -0.00111F, 0.00137F, -0.00065F, 0.00076F, -0.00039F, 0.00042F, -0.00021F, 0.00020F, + -0.00012F, 0.00015F, -0.00007F, 0.00005F, -0.00005F, 0.00005F, -0.00002F, 0.00000F, + -0.00002F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + {0.01154F, 0.01331F, 0.00077F, 0.01318F, -0.00244F, 0.00769F, -0.00232F, 0.00403F, + -0.00157F, 0.00212F, -0.00097F, 0.00115F, -0.00055F, 0.00064F, -0.00032F, 0.00036F, + -0.00019F, 0.00020F, -0.00010F, 0.00010F, -0.00005F, 0.00005F, -0.00002F, 0.00005F, + -0.00002F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + {0.01508F, 0.00558F, 0.00531F, 0.01508F, -0.00159F, 0.01038F, -0.00255F, 0.00562F, + -0.00194F, 0.00294F, -0.00125F, 0.00154F, -0.00073F, 0.00083F, -0.00043F, 0.00047F, + -0.00023F, 0.00025F, -0.00014F, 0.00015F, -0.00007F, 0.00010F, -0.00005F, 0.00005F, + -0.00002F, 0.00005F, -0.00002F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}}, + {{0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + {-0.00238F, 0.00812F, -0.00238F, 0.00433F, -0.00163F, 0.00223F, -0.00100F, 0.00123F, + -0.00060F, 0.00065F, -0.00035F, 0.00036F, -0.00019F, 0.00020F, -0.00012F, 0.00010F, + -0.00007F, 0.00005F, -0.00002F, 0.00005F, -0.00002F, 0.00000F, 0.00000F, 0.00000F, + 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + {0.00319F, 0.01459F, -0.00208F, 0.00922F, -0.00250F, 0.00488F, -0.00177F, 0.00259F, + -0.00111F, 0.00137F, -0.00065F, 0.00076F, -0.00039F, 0.00042F, -0.00021F, 0.00020F, + -0.00012F, 0.00015F, -0.00007F, 0.00005F, -0.00005F, 0.00005F, -0.00002F, 0.00000F, + -0.00002F, 0.00000F, 
0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + {0.01154F, 0.01331F, 0.00077F, 0.01318F, -0.00244F, 0.00769F, -0.00232F, 0.00403F, + -0.00157F, 0.00212F, -0.00097F, 0.00115F, -0.00055F, 0.00064F, -0.00032F, 0.00036F, + -0.00019F, 0.00020F, -0.00010F, 0.00010F, -0.00005F, 0.00005F, -0.00002F, 0.00005F, + -0.00002F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}, + {0.01508F, 0.00558F, 0.00531F, 0.01508F, -0.00159F, 0.01038F, -0.00255F, 0.00562F, + -0.00194F, 0.00294F, -0.00125F, 0.00154F, -0.00073F, 0.00083F, -0.00043F, 0.00047F, + -0.00023F, 0.00025F, -0.00014F, 0.00015F, -0.00007F, 0.00010F, -0.00005F, 0.00005F, + -0.00002F, 0.00005F, -0.00002F, 0.00000F, 0.00000F, 0.00000F, 0.00000F, 0.00000F}}}}; + + auto* device = &ttml::autograd::ctx().get_device(); + auto rope_params = ops::build_rope_params( + /*sequence_length=*/5, + /*head_dim=*/32); + auto rope_mod = modules::RotaryEmbedding(rope_params); + + auto xq_autograd_tensor = autograd::create_tensor(core::from_xtensor(xq, device)); + + auto actual_xq_out = rope_mod(xq_autograd_tensor); + auto target = autograd::create_tensor(core::from_xtensor(xq, device)); // just need ones for mse target, reusing xq + + auto loss = ttml::ops::mse_loss(actual_xq_out, target); + loss->backward(); + + auto actual_grad = core::to_xtensor(xq_autograd_tensor->get_grad()); + // NOTE(review): every entry of expected_grad has magnitude <= 0.01508, far below the 2e-1 absolute + // tolerance passed here, so this comparison would presumably also accept an all-zero gradient. + // Consider tightening rtol/atol once the achievable on-device precision is confirmed. + EXPECT_TRUE(xt::allclose(actual_grad, expected_grad, 2e-1, 2e-1)); +} + +} // namespace ttml::modules::tests diff --git a/ttnn/cpp/ttnn/tensor/xtensor/xtensor_all_includes.hpp b/ttnn/cpp/ttnn/tensor/xtensor/xtensor_all_includes.hpp index 12bdd2addb8..ca5b90817bd 100644 --- a/ttnn/cpp/ttnn/tensor/xtensor/xtensor_all_includes.hpp +++ b/ttnn/cpp/ttnn/tensor/xtensor/xtensor_all_includes.hpp @@ -16,3 +16,5 @@ #include #include #include +#include +#include