[tt-train] Add RoPE op and module (#18465)
### Problem description
- We need rotary position embeddings (RoPE) for Llama 3.

### What's changed
- Adds a TTML op with forward and backward passes for RoPE.
- Adds a TTML module for rotary embedding, built on the RoPE op.
- Adds forward/backward tests for the RoPE module.

### Tests
- [x] New tests pass
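
Example usage of the new module (a minimal sketch; the sizes are illustrative and `input` is assumed to be a rank-4 `[batch, num_heads, sequence_length, head_dim]` autograd tensor already on the device):

```cpp
#include "modules/rotary_embedding.hpp"
#include "ops/rope_op.hpp"

// Precompute the cos/sin caches and transformation matrix once per
// (sequence_length, head_dim) configuration.
auto rope_params = ttml::ops::build_rope_params(/*sequence_length=*/128, /*head_dim=*/64);
ttml::modules::RotaryEmbedding rope(rope_params);

// Apply the rotation; gradients flow back through the op's registered
// backward function when a downstream loss runs its backward pass.
auto rotated = rope(input);
```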
jaykru-tt authored Mar 6, 2025
1 parent 24444c3 commit 2ad02e3
Showing 9 changed files with 782 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tt-train/sources/ttml/core/ttnn_all_includes.hpp
@@ -72,4 +72,6 @@
#include <ttnn/tensor/xtensor/partition.hpp> // NOLINT
#include <ttnn/tensor/xtensor/xtensor_all_includes.hpp> // NOLINT
#include <ttnn/types.hpp> // NOLINT

#include "ttnn/operations/experimental/transformer/rotary_embedding_llama/rotary_embedding_llama.hpp" // NOLINT
#pragma GCC diagnostic pop
19 changes: 19 additions & 0 deletions tt-train/sources/ttml/modules/rotary_embedding.cpp
@@ -0,0 +1,19 @@
// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "modules/rotary_embedding.hpp"

#include "autograd/auto_context.hpp"
#include "ops/rope_op.hpp"

namespace ttml::modules {

RotaryEmbedding::RotaryEmbedding(const ops::RotaryEmbeddingParams& rope_params) : m_rope_params(rope_params) {}

autograd::TensorPtr RotaryEmbedding::operator()(const autograd::TensorPtr& input) {
return ttml::ops::rope(input, m_rope_params);
}

} // namespace ttml::modules
20 changes: 20 additions & 0 deletions tt-train/sources/ttml/modules/rotary_embedding.hpp
@@ -0,0 +1,20 @@
// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "autograd/module_base.hpp"
#include "ops/rope_op.hpp"

namespace ttml::modules {
class RotaryEmbedding : public autograd::ModuleBase {
private:
ops::RotaryEmbeddingParams m_rope_params;

public:
explicit RotaryEmbedding(const ops::RotaryEmbeddingParams& rope_params);
[[nodiscard]] autograd::TensorPtr operator()(const autograd::TensorPtr& input) override;
};

} // namespace ttml::modules
167 changes: 167 additions & 0 deletions tt-train/sources/ttml/ops/rope_op.cpp
@@ -0,0 +1,167 @@
// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "ops/rope_op.hpp"

#include <numbers>

#include "autograd/auto_context.hpp"
#include "autograd/graph.hpp"
#include "autograd/graph_utils.hpp"
#include "autograd/tensor.hpp"
#include "core/tt_tensor_utils.hpp"
#include "core/ttnn_all_includes.hpp"
#include "core/xtensor_utils.hpp"
#include "ttnn/tensor/xtensor/xtensor_all_includes.hpp"
#include "ttnn_fixed/trivial_ttnn_ops.hpp"

namespace ttml::ops {

void validate_rope_input_and_params(const autograd::TensorPtr& input, const RotaryEmbeddingParams& params) {
if (input->get_rank() != 4U) {
throw std::runtime_error(
fmt::format("RoPE only supports rank-4 input tensors, but got rank {}.", input->get_rank()));
}
auto input_shape = input->get_shape();

auto input_seq_len = input_shape[-2];
auto input_head_dim = input_shape[-1];

if (input_head_dim != params.head_dim) {
throw std::runtime_error(fmt::format(
"RoPE input tensor's head dimension ({}) must match the head dimension in the params ({})",
input_head_dim,
params.head_dim));
}

if (input_seq_len != params.sequence_length) {
throw std::runtime_error(fmt::format(
"RoPE input tensor's sequence length ({}) must match the sequence length in the params ({})",
input_seq_len,
params.sequence_length));
}

auto trans_mat_shape = params.trans_mat.get_logical_shape();
auto trig_param_shapes = std::array{
params.cos_cache.get_logical_shape(),
params.sin_cache.get_logical_shape(),
params.neg_cos_cache.get_logical_shape(),
params.neg_sin_cache.get_logical_shape()};

auto expected_trig_shape = ttnn::Shape{1U, 1U, input_seq_len, input_head_dim};
if (!std::ranges::all_of(
trig_param_shapes, [&expected_trig_shape](auto shape) { return shape == expected_trig_shape; })) {
throw std::runtime_error(fmt::format(
"All trigonometric rotary embedding parameters must have shape [1, 1, {}, {}], but got shapes: "
"cos_cache: {}, sin_cache: {}, neg_cos_cache: {}, neg_sin_cache: {}",
input_seq_len,
input_head_dim,
params.cos_cache.get_logical_shape(),
params.sin_cache.get_logical_shape(),
params.neg_cos_cache.get_logical_shape(),
params.neg_sin_cache.get_logical_shape()));
}

auto expected_trans_mat_shape = ttnn::Shape{1U, 1U, 32U, 32U};
if (trans_mat_shape != expected_trans_mat_shape) {
throw std::runtime_error(fmt::format(
"RoPE trans_mat must be of shape {}, but has shape {}", expected_trans_mat_shape, trans_mat_shape));
}
}

// trans_mat, sin_cache, cos_cache are all precomputed and stored somewhere in
// the module hierarchy and passed to the operation.
autograd::TensorPtr rope(const autograd::TensorPtr& input, const RotaryEmbeddingParams& params) {
validate_rope_input_and_params(input, params);

auto out_tensor = ttnn::experimental::rotary_embedding_llama(
input->get_value(), params.cos_cache, params.sin_cache, params.trans_mat);
auto out = autograd::create_tensor(out_tensor);

// In the backward pass we rotate by -θ, so we pass the caches for the
// negated angle: cos(-θ) = cos(θ) and sin(-θ) = -sin(θ). Note: we can reuse
// trans_mat here since the data movement is the same in the backward pass
// (the same matmul trick applies; the rotation itself is specified by the
// cos/sin caches).
autograd::GradFunction grad_fn = [input, params, out]() {
auto dL_dout = out->get_grad();

auto dL_dinput = ttnn::experimental::rotary_embedding_llama(
dL_dout, params.neg_cos_cache, params.neg_sin_cache, params.trans_mat);
input->add_grad(dL_dinput);
};

auto links = autograd::get_links(input);
out->set_node(autograd::ctx().add_backward_node(std::move(grad_fn), links));

return out;
}

std::pair<ttnn::Tensor, ttnn::Tensor> gen_freqs(uint32_t head_dim, uint32_t sequence_length, float theta = 10000.0F) {
auto d = static_cast<int>(head_dim);
// compute freqs: 1.0 / (theta ** (2 * floor(i / 2) / head_dim)) for i in [0, head_dim)
xt::xarray<uint32_t> expt_data = xt::arange(0, d) / 2;
xt::xarray<float> expt_xt = xt::cast<float>(expt_data);

expt_xt *= 2.0F / static_cast<float>(head_dim);
xt::xarray<float> theta_pow = xt::pow(theta, expt_xt);

auto freqs = xt::ones_like(theta_pow) / theta_pow;
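// worked example (illustration): head_dim = 4, theta = 10000 gives
// expt = [0, 0, 0.5, 0.5], theta_pow = [1, 1, 100, 100], and
// freqs = [1, 1, 0.01, 0.01] -- each frequency appears twice, once per
// element of a rotated pair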

xt::xarray<float> seq_pos = xt::arange<float>(sequence_length);
xt::xarray<float> seq_pos_repeated_to_head = xt::repeat(seq_pos, head_dim, seq_pos.dimension() - 1U);
xt::xarray<float> scales = seq_pos_repeated_to_head.reshape({sequence_length, static_cast<uint32_t>(head_dim)});

xt::xarray<float> scaled_freqs = scales * freqs;

// take the scaled freqs mod 2π to satisfy ttnn inputs constraints for sin/cos
auto pi = static_cast<float>(std::numbers::pi);
scaled_freqs = xt::fmod(scaled_freqs, 2.0F * pi);
scaled_freqs = scaled_freqs.reshape({1, 1, sequence_length, head_dim});

xt::xarray<float> sin_freqs = xt::sin(scaled_freqs);
xt::xarray<float> cos_freqs = xt::cos(scaled_freqs);

auto* device = &autograd::ctx().get_device();
return {core::from_xtensor(sin_freqs, device), core::from_xtensor(cos_freqs, device)};
}

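// Builds the pair-rotation matrix used by the matmul trick. For head_dim = 4
// the pattern is (illustration):
// [[ 0, 1, 0, 0],
//  [-1, 0, 0, 0],
//  [ 0, 0, 0, 1],
//  [ 0, 0, -1, 0]]
// so multiplying a row vector x by trans_mat maps each pair (x0, x1) to
// (-x1, x0), i.e. the "rotate half" step of RoPE.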
ttnn::Tensor gen_trans_mat(int head_dim) {
xt::xarray<float> trans_mat = xt::zeros<float>({1, 1, head_dim, head_dim});
for (int i = 0; i < head_dim; i += 2) {
trans_mat(0, 0, i, i + 1) = 1.0F;
}
for (int j = 1; j < head_dim; j += 2) {
trans_mat(0, 0, j, j - 1) = -1.0F;
}

auto* device = &autograd::ctx().get_device();
return core::from_xtensor(trans_mat, device);
}

RotaryEmbeddingParams build_rope_params(uint32_t sequence_length, uint32_t head_dim, float theta) {
if (head_dim == 0U) {
throw std::invalid_argument("RoPE head_dim must be greater than 0");
}
if (head_dim % 32U != 0U) {
throw std::invalid_argument("RoPE head_dim must be divisible by 32");
}
if (head_dim > 256U) {
throw std::invalid_argument("RoPE head_dim must be less than or equal to 256");
}
auto [sin_freqs, cos_freqs] = gen_freqs(head_dim, sequence_length, theta);
// the llama rotary embedding kernel consumes a single 32x32 transformation
// tile regardless of head_dim (see the trans_mat shape check in
// validate_rope_input_and_params)
auto trans_mat = gen_trans_mat(32);

return {
.cos_cache = cos_freqs,
.sin_cache = sin_freqs,
.neg_cos_cache = cos_freqs, // cos(-θ) = cos(θ): cosine is even
.neg_sin_cache = ttnn::neg(sin_freqs), // sin(-θ) = -sin(θ): sine is odd
.trans_mat = trans_mat,

.sequence_length = sequence_length,
.head_dim = head_dim,
};
}

} // namespace ttml::ops
29 changes: 29 additions & 0 deletions tt-train/sources/ttml/ops/rope_op.hpp
@@ -0,0 +1,29 @@
// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "autograd/tensor.hpp"

namespace ttml::ops {

struct RotaryEmbeddingParams {
ttnn::Tensor cos_cache;
ttnn::Tensor sin_cache;
ttnn::Tensor neg_cos_cache;
ttnn::Tensor neg_sin_cache;
ttnn::Tensor trans_mat;

uint32_t sequence_length = 0;
uint32_t head_dim = 0;
};

autograd::TensorPtr rope(const autograd::TensorPtr& input, const RotaryEmbeddingParams& rope_params);

RotaryEmbeddingParams build_rope_params(uint32_t sequence_length, uint32_t head_dim, float theta = 10000.0F);
// Throws an exception if the input is bad, parameters are bad, or the two are
// incompatible with one another.
void validate_rope_input_and_params(const autograd::TensorPtr& input, const RotaryEmbeddingParams& rope_params);

} // namespace ttml::ops
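
For reference, the op can also be called directly; a sketch under the same assumptions as above (illustrative sizes, an `input` tensor already on the device):

```cpp
using namespace ttml;

// head_dim must be a positive multiple of 32 and at most 256
// (enforced by build_rope_params).
ops::RotaryEmbeddingParams params = ops::build_rope_params(
    /*sequence_length=*/256, /*head_dim=*/128, /*theta=*/10000.0F);

// rope() validates internally; the explicit call below just illustrates the check.
ops::validate_rope_input_and_params(input, params);
autograd::TensorPtr rotated = ops::rope(input, params);
```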
8 changes: 8 additions & 0 deletions tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.cpp
@@ -76,4 +76,12 @@ tt::tt_metal::Tensor sum_ttnn(const tt::tt_metal::Tensor& t, int dim, bool keep_
return ttnn::sum(t, dim, keep_dim, std::nullopt, core::ComputeKernelConfig::precise());
}

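// Thin wrappers over ttnn::to_memory_config for moving a tensor between
// interleaved L1 (on-chip SRAM) and interleaved DRAM placements.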
tt::tt_metal::Tensor to_l1_interleaved(const tt::tt_metal::Tensor& t) {
return ttnn::to_memory_config(t, ttnn::L1_MEMORY_CONFIG);
}

tt::tt_metal::Tensor to_dram_interleaved(const tt::tt_metal::Tensor& t) {
return ttnn::to_memory_config(t, ttnn::DRAM_MEMORY_CONFIG);
}

} // namespace ttml::ttnn_fixed
4 changes: 4 additions & 0 deletions tt-train/sources/ttml/ttnn_fixed/trivial_ttnn_ops.hpp
@@ -20,4 +20,8 @@ tt::tt_metal::Tensor mean_ttnn(const tt::tt_metal::Tensor& t, int dim, bool keep

tt::tt_metal::Tensor sum_moreh(const tt::tt_metal::Tensor& t, int dim, bool keep_dim);
tt::tt_metal::Tensor sum_ttnn(const tt::tt_metal::Tensor& t, int dim, bool keep_dim);

tt::tt_metal::Tensor to_l1_interleaved(const tt::tt_metal::Tensor& t);
tt::tt_metal::Tensor to_dram_interleaved(const tt::tt_metal::Tensor& t);

} // namespace ttml::ttnn_fixed