Skip to content

Commit

Permalink
#15061: Implement multi-device tensor distribution APIs in terms of C…
Browse files Browse the repository at this point in the history
…++ ttnn tensors (#15886)

### Ticket
#15755

### Problem description
Multi-device tensor distribution currently works through
`distributed.py`, which relies on PyTorch libraries to perform sharding
/ concatenation.

### What's changed
* Add xtensor to ttnn.
* Lower facilities from tt-train down to ttnn. In particular: `chunk`,
`concatenate` functions along with some conversion utils, and the
relevant tests.
* Add `distributed_tensor.hpp` header with the multi-device distribution
APIs.

**In follow up PRs:**
* Support bf4 / bf8 and other formats in `from_vector` / `to_vector` and
other overloads.
* Support outputting a tilized tensor.
* Migrate functionality from `pytensor.cpp` to using the new APIs.

### Checklist
- [x] [Post commit CI
passes](https://github.com/tenstorrent/tt-metal/actions/runs/12333746639/job/34427015707)
(failure in clang-tidy in unrelated tt-train directory)
- [X] [code analysis
run](https://github.com/tenstorrent/tt-metal/actions/runs/12360844971)
- [x] [T3K unit + frequent + model reg
tests](https://github.com/tenstorrent/tt-metal/actions/runs/12360656141)
- same breakage on main.
- [X] New/Existing tests provide coverage for changes
  • Loading branch information
omilyutin-tt authored Dec 17, 2024
1 parent 10d11b6 commit 60f2d28
Show file tree
Hide file tree
Showing 31 changed files with 1,265 additions and 422 deletions.
14 changes: 14 additions & 0 deletions dependencies/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,17 @@ CPMAddPackage(NAME pybind11 GITHUB_REPOSITORY pybind/pybind11 GIT_TAG b8f28551cc
############################################################################################################################

CPMAddPackage(NAME json GITHUB_REPOSITORY nlohmann/json GIT_TAG v3.9.1)

############################################################################################################################
# xtensor : https://github.com/xtensor-stack/xtensor
############################################################################################################################

CPMAddPackage(
    NAME xtl
    GITHUB_REPOSITORY xtensor-stack/xtl
    GIT_TAG 0.7.7
    OPTIONS
        "XTL_ENABLE_TESTS OFF"
)
CPMAddPackage(
    NAME xtensor
    GITHUB_REPOSITORY xtensor-stack/xtensor
    GIT_TAG 0.25.0
    OPTIONS
        "XTENSOR_ENABLE_TESTS OFF"
)
CPMAddPackage(
    NAME xtensor-blas
    GITHUB_REPOSITORY xtensor-stack/xtensor-blas
    GIT_TAG 0.21.0
    OPTIONS
        "XTENSOR_ENABLE_TESTS OFF"
)
4 changes: 4 additions & 0 deletions tests/ttnn/unit_tests/gtests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,12 @@ set(TTNN_TENSOR_UNIT_TESTS_SRC
${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_tensor_layout.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_create_tensor_multi_device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_create_tensor_with_layout.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_distributed_tensor.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_partition.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_shape_base.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_sharding_with_alignment.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_vector_conversion.cpp
${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_xtensor_conversion.cpp
)

add_executable(unit_tests_ttnn ${TTNN_UNIT_TESTS_SRC})
Expand Down
186 changes: 186 additions & 0 deletions tests/ttnn/unit_tests/gtests/tensor/test_distributed_tensor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include <gtest/gtest.h>
#include <gmock/gmock.h>

#include "ttnn/distributed/api.hpp"
#include "ttnn/operations/functions.hpp"
#include "ttnn_test_fixtures.hpp"
#include <ttnn/distributed/types.hpp>
#include <ttnn/distributed/distributed_tensor.hpp>

namespace ttnn::distributed::test {

using ::testing::ElementsAre;

using TensorDistributionTest = T3kMultiDeviceFixture;

// Builds a row-major TensorSpec with a default (empty) MemoryConfig for the
// given shape and dtype; shared by all tests in this file.
TensorSpec get_tensor_spec(const ttnn::SimpleShape& shape, DataType dtype) {
return TensorSpec(shape, TensorLayout(dtype, Layout::ROW_MAJOR, MemoryConfig{}));
}

// Replication: every device in the mesh receives a full copy of the input.
TEST_F(TensorDistributionTest, Replication) {
    Tensor host_tensor = Tensor::from_vector(
        std::vector<float>{42.F, 13.F, -99.F}, get_tensor_spec(ttnn::SimpleShape{1, 1, 1, 3}, DataType::FLOAT32));

    auto replicate_mapper = replicate_tensor_to_mesh_mapper(*mesh_device_);
    Tensor distributed = distribute_tensor(host_tensor, *mesh_device_, *replicate_mapper);

    std::vector<Tensor> shards = get_device_tensors(distributed);
    EXPECT_EQ(shards.size(), mesh_device_->num_devices());
    for (const Tensor& shard : shards) {
        EXPECT_THAT(shard.to_vector<float>(), ElementsAre(42.F, 13.F, -99.F));
    }
}

// Sharding along a dimension outside the tensor's rank (negative or too large)
// must fail somewhere in the mapper/distribute pipeline.
TEST_F(TensorDistributionTest, Shard1DInvalidDim) {
    const int num_devices = mesh_device_->num_devices();
    Tensor host_tensor = Tensor::from_vector(
        std::vector<float>(num_devices, 0),
        get_tensor_spec(ttnn::SimpleShape{1, 1, 1, num_devices}, DataType::FLOAT32));

    for (int invalid_dim : {-1, 4}) {
        EXPECT_ANY_THROW({
            auto mapper = shard_tensor_to_mesh_mapper(*mesh_device_, invalid_dim);
            Tensor sharded = distribute_tensor(host_tensor, *mesh_device_, *mapper);
        });
    }
}

// Producing fewer shards (3) than there are devices in the mesh is an error.
TEST_F(TensorDistributionTest, Shard1DTooFewShards) {
    ASSERT_LT(3, mesh_device_->num_devices());
    Tensor host_tensor = Tensor::from_vector(
        std::vector<float>{42.F, 13.F, -99.F}, get_tensor_spec(ttnn::SimpleShape{1, 1, 1, 3}, DataType::FLOAT32));

    EXPECT_ANY_THROW({
        auto mapper = shard_tensor_to_mesh_mapper(*mesh_device_, 3);
        Tensor sharded = distribute_tensor(host_tensor, *mesh_device_, *mapper);
    });
}

// End-to-end 1D round trip: shard along dim 1 (one row of 3 values per device),
// verify each device's shard, then concatenate back along dim 0 and compare
// against a tensor with the equivalent reshaped spec.
TEST_F(TensorDistributionTest, Shard1D) {
    const int num_devices = mesh_device_->num_devices();
    std::vector<float> test_data;
    for (int i = 0; i < num_devices; i++) {
        test_data.insert(test_data.end(), {i * 1.F, i * 2.F, i * 3.F});
    }
    Tensor input_tensor =
        Tensor::from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{1, num_devices, 3, 1}, DataType::FLOAT32));

    auto mapper = shard_tensor_to_mesh_mapper(*mesh_device_, 1);
    Tensor sharded_tensor = distribute_tensor(input_tensor, *mesh_device_, *mapper);

    std::vector<Tensor> device_tensors = get_device_tensors(sharded_tensor);
    EXPECT_EQ(device_tensors.size(), mesh_device_->num_devices());
    // size_t index: avoids the signed/unsigned comparison against .size().
    for (size_t i = 0; i < device_tensors.size(); i++) {
        const float f = static_cast<float>(i);
        EXPECT_THAT(device_tensors[i].to_vector<float>(), ElementsAre(f * 1.F, f * 2.F, f * 3.F));
    }

    auto composer = concat_mesh_to_tensor_composer(/*dim=*/0);
    Tensor concatenated_tensor = aggregate_tensor(sharded_tensor, *composer);

    Tensor expected_tensor =
        Tensor::from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{num_devices, 1, 3, 1}, DataType::FLOAT32));
    EXPECT_TRUE(ttnn::allclose<float>(concatenated_tensor, expected_tensor));
}

// Requested 2D mesh shapes that do not fit the physical 2x4 mesh are rejected.
TEST_F(TensorDistributionTest, Shard2DInvalidMeshShape) {
    const auto [num_rows, num_cols] = mesh_device_->shape();
    ASSERT_EQ(num_rows, 2);
    ASSERT_EQ(num_cols, 4);

    const Shard2dConfig shard_config{.row_dim = 1, .col_dim = 2};
    // Too many rows requested.
    EXPECT_ANY_THROW(shard_tensor_to_2d_mesh_mapper(*mesh_device_, MeshShape{3, 1}, shard_config));
    // Too many columns requested.
    EXPECT_ANY_THROW(shard_tensor_to_2d_mesh_mapper(*mesh_device_, MeshShape{2, 5}, shard_config));
}

// A default-constructed Shard2dConfig (neither row_dim nor col_dim set) must
// be rejected by the 2D mapper factory.
TEST_F(TensorDistributionTest, Shard2DInvalidShardConfig) {
EXPECT_ANY_THROW(shard_tensor_to_2d_mesh_mapper(*mesh_device_, MeshShape{2, 4}, Shard2dConfig{}));
}

// A Concat2dConfig where row_dim equals col_dim is ambiguous and must be
// rejected by the 2D composer factory.
TEST_F(TensorDistributionTest, Concat2DInvalidConfig) {
EXPECT_ANY_THROW(concat_2d_mesh_to_tensor_composer(*mesh_device_, Concat2dConfig{.row_dim = 2, .col_dim = 2}));
}

// 2D sharding with only row_dim set: each row shard is replicated across all
// columns of the mesh, so devices 0..3 see the first row of data and devices
// 4..7 see the second.
TEST_F(TensorDistributionTest, Shard2DReplicateDim) {
    const auto [num_rows, num_cols] = mesh_device_->shape();
    ASSERT_EQ(num_rows, 2);
    ASSERT_EQ(num_cols, 4);

    std::vector<float> test_data = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0};
    Tensor input_tensor =
        Tensor::from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{1, num_rows, num_cols, 1}, DataType::FLOAT32));

    auto mapper = shard_tensor_to_2d_mesh_mapper(
        *mesh_device_,
        MeshShape{num_rows, num_cols},
        Shard2dConfig{
            .row_dim = 1,
        });
    Tensor sharded_tensor = distribute_tensor(input_tensor, *mesh_device_, *mapper);

    std::vector<Tensor> device_tensors = get_device_tensors(sharded_tensor);
    EXPECT_EQ(device_tensors.size(), mesh_device_->num_devices());

    // First mesh row of devices holds the first data row; the rest hold the second.
    size_t i = 0;
    for (; i < 4; i++) {
        EXPECT_THAT(device_tensors[i].to_vector<float>(), ElementsAre(0.0, 1.0, 2.0, 3.0));
    }
    for (; i < device_tensors.size(); i++) {
        EXPECT_THAT(device_tensors[i].to_vector<float>(), ElementsAre(4.0, 5.0, 6.0, 7.0));
    }
}

// Full 2D round trip: shard across both mesh dimensions (rows along dim 1,
// cols along dim 2), verify per-device shards, then concatenate back and
// compare against the equivalently reshaped expected tensor.
TEST_F(TensorDistributionTest, Shard2D) {
    const auto [num_rows, num_cols] = mesh_device_->shape();
    ASSERT_EQ(num_rows, 2);
    ASSERT_EQ(num_cols, 4);
    const int num_devices = num_rows * num_cols;

    std::vector<float> test_data;
    for (int i = 0; i < num_devices; i++) {
        test_data.insert(test_data.end(), {i * 1.F, i * 2.F, i * 3.F});
    }
    Tensor input_tensor =
        Tensor::from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{1, num_rows, num_cols, 3}, DataType::FLOAT32));

    auto mapper = shard_tensor_to_2d_mesh_mapper(
        *mesh_device_,
        MeshShape{num_rows, num_cols},
        Shard2dConfig{
            .row_dim = 1,
            .col_dim = 2,
        });
    Tensor sharded_tensor = distribute_tensor(input_tensor, *mesh_device_, *mapper);

    std::vector<Tensor> device_tensors = get_device_tensors(sharded_tensor);
    EXPECT_EQ(device_tensors.size(), mesh_device_->num_devices());
    // size_t index: avoids the signed/unsigned comparison against .size().
    for (size_t i = 0; i < device_tensors.size(); i++) {
        const float f = static_cast<float>(i);
        EXPECT_THAT(device_tensors[i].to_vector<float>(), ElementsAre(f * 1.F, f * 2.F, f * 3.F));
    }

    auto composer = concat_2d_mesh_to_tensor_composer(
        *mesh_device_,
        Concat2dConfig{
            .row_dim = 0,
            .col_dim = 2,
        });
    Tensor concatenated_tensor = aggregate_tensor(sharded_tensor, *composer);

    Tensor expected_tensor =
        Tensor::from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{num_rows, 1, num_cols, 3}, DataType::FLOAT32));
    EXPECT_TRUE(ttnn::allclose<float>(concatenated_tensor, expected_tensor));
}

} // namespace ttnn::distributed::test
120 changes: 120 additions & 0 deletions tests/ttnn/unit_tests/gtests/tensor/test_partition.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include <gtest/gtest.h>
#include <gmock/gmock.h>

#include "ttnn/tensor/tensor.hpp"
#include "ttnn/tensor/xtensor/conversion_utils.hpp"
#include "ttnn/tensor/xtensor/partition.hpp"
#include "ttnn/tensor/xtensor/xtensor_all_includes.hpp"

namespace ttnn {
namespace {

using ::testing::SizeIs;
using ::tt::tt_metal::Tensor;
using ::ttnn::experimental::xtensor::chunk;
using ::ttnn::experimental::xtensor::concat;

// chunk() on a length-10 1D array into 3 pieces along axis 0 yields
// sizes 4, 4, 2 (last chunk takes the remainder).
TEST(PartitionTest, ChunkBasicNonDivisible3) {
    xt::xarray<float> input = xt::arange<float>(10);  // [0, 1, ..., 9]

    auto pieces = chunk(input, 3, 0);

    ASSERT_THAT(pieces, SizeIs(3));
    EXPECT_EQ(pieces[0].shape()[0], 4u);
    EXPECT_EQ(pieces[1].shape()[0], 4u);
    EXPECT_EQ(pieces[2].shape()[0], 2u);
}

// Requesting 6 chunks of a length-13 array produces only 5: four full chunks
// of 3 plus a final chunk of 1 — chunk() may return fewer pieces than asked.
TEST(PartitionTest, ChunkBasicLessChunksThanProvided) {
    xt::xarray<float> input = xt::arange<float>(13);  // [0, 1, ..., 12]

    auto pieces = chunk(input, 6, 0);

    ASSERT_THAT(pieces, SizeIs(5));
    for (int idx = 0; idx < 4; ++idx) {
        EXPECT_EQ(pieces[idx].shape()[0], 3u);
    }
    EXPECT_EQ(pieces[4].shape()[0], 1u);
}

// concat() with no axis argument joins along axis 0.
TEST(PartitionTest, DefaultAxis) {
    xt::xarray<double> a = {{1.0, 2.0}, {3.0, 4.0}};
    xt::xarray<double> b = {{5.0, 6.0}, {7.0, 8.0}};
    std::vector<xt::xarray<double>> input = {a, b};

    xt::xarray<double> result = concat(input);  // axis=0 by default
    xt::xarray<double> expected = {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}, {7.0, 8.0}};

    // Bug fix: xt::allclose returns a bool that was previously discarded, so
    // this test asserted nothing. Assert it explicitly.
    EXPECT_TRUE(xt::allclose(result, expected));
}

// concat() along axis 1 joins columns; row counts must match.
TEST(PartitionTest, AxisOne) {
    xt::xarray<int> x = {{1, 2, 3}, {4, 5, 6}};
    xt::xarray<int> y = {{7, 8}, {9, 10}};
    std::vector<xt::xarray<int>> input = {x, y};

    xt::xarray<int> result = concat(input, 1);
    xt::xarray<int> expected = {{1, 2, 3, 7, 8}, {4, 5, 6, 9, 10}};

    // Bug fix: the result of xt::allclose was previously discarded, so this
    // test asserted nothing. Assert it explicitly.
    EXPECT_TRUE(xt::allclose(result, expected));
}

// concat() of three 1D arrays along axis 0 yields one flat array.
TEST(PartitionTest, MultipleArraysAxis0) {
    xt::xarray<float> a = {1.0f, 2.0f};
    xt::xarray<float> b = {3.0f, 4.0f};
    xt::xarray<float> c = {5.0f, 6.0f};
    std::vector<xt::xarray<float>> input = {a, b, c};

    xt::xarray<float> result = concat(input, 0);
    xt::xarray<float> expected = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};

    // Bug fix: the result of xt::allclose was previously discarded, so this
    // test asserted nothing. Assert it explicitly.
    EXPECT_TRUE(xt::allclose(result, expected));
}

// Concatenating with a default-constructed (empty) array is an error.
TEST(PartitionTest, EmptyArray) {
    xt::xarray<int> filled = {{1, 2}, {3, 4}};
    xt::xarray<int> empty;
    std::vector<xt::xarray<int>> input = {filled, empty};

    EXPECT_ANY_THROW({ xt::xarray<int> result = concat(input, 0); });
}

// concat() of two (2,2,2) arrays along axis 0 matches xt::concatenate,
// producing a (4,2,2) result.
TEST(PartitionTest, HigherDimensions) {
    xt::xarray<int> arr1 = xt::arange<int>(1, 9);  // 1 to 8
    arr1.reshape({2, 2, 2});
    xt::xarray<int> arr2 = xt::arange<int>(9, 17);  // 9 to 16
    arr2.reshape({2, 2, 2});

    std::vector<xt::xarray<int>> input = {arr1, arr2};
    xt::xarray<int> result = concat(input, 0);

    xt::xarray<int> expected = xt::concatenate(xt::xtuple(arr1, arr2), 0);

    // Bug fix: the result of xt::allclose was previously discarded, so this
    // test asserted nothing. Assert it explicitly.
    EXPECT_TRUE(xt::allclose(result, expected));
}

// concat() of two (2,2,2) arrays along the last axis yields shape (2,2,4).
TEST(PartitionTest, HigherAxis) {
    xt::xarray<int> arr1 = {{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}};
    xt::xarray<int> arr2 = {{{9, 10}, {11, 12}}, {{13, 14}, {15, 16}}};

    std::vector<xt::xarray<int>> input = {arr1, arr2};
    xt::xarray<int> result = concat(input, 2);
    xt::xarray<int> expected = {{{1, 2, 9, 10}, {3, 4, 11, 12}}, {{5, 6, 13, 14}, {7, 8, 15, 16}}};

    // Bug fix: the result of xt::allclose was previously discarded, so this
    // test asserted nothing. Assert it explicitly.
    EXPECT_TRUE(xt::allclose(result, expected));
}

} // namespace
} // namespace ttnn
Loading

0 comments on commit 60f2d28

Please sign in to comment.