-
Notifications
You must be signed in to change notification settings - Fork 117
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#15061: Implement multi-device tensor distribution APIs in terms of C…
…++ ttnn tensors (#15886) ### Ticket #15755 ### Problem description Multi-device tensor distribution currently works through `distributed.py`, which relies on PyTorch libraries to perform sharding / concatenation. ### What's changed * Add xtensor to ttnn. * Lower facilities from tt-train down to ttnn. In particular: `chunk`, `concatenate` functions along with some conversion utils, and the relevant tests. * Add `distributed_tensor.hpp` header with the multi-device distribution APIs. **In follow up PRs:** * Support bf4 / bf8 and other formats in `from_vector` / `to_vector` and other overloads. * Support outputting a tilized tensor. * Migrate functionality from `pytensor.cpp` to using the new APIs. ### Checklist - [x] [Post commit CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/12333746639/job/34427015707) (failure in clang-tidy in unrelated tt-train directory) - [X] [code analysis run](https://github.com/tenstorrent/tt-metal/actions/runs/12360844971) - [x] [T3K unit + frequent + model reg tests](https://github.com/tenstorrent/tt-metal/actions/runs/12360656141) - same breakage on main. - [X] New/Existing tests provide coverage for changes
- Loading branch information
1 parent
10d11b6
commit 60f2d28
Showing
31 changed files
with
1,265 additions
and
422 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
186 changes: 186 additions & 0 deletions
186
tests/ttnn/unit_tests/gtests/tensor/test_distributed_tensor.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,186 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include <gtest/gtest.h> | ||
#include <gmock/gmock.h> | ||
|
||
#include "ttnn/distributed/api.hpp" | ||
#include "ttnn/operations/functions.hpp" | ||
#include "ttnn_test_fixtures.hpp" | ||
#include <ttnn/distributed/types.hpp> | ||
#include <ttnn/distributed/distributed_tensor.hpp> | ||
|
||
namespace ttnn::distributed::test { | ||
|
||
using ::testing::ElementsAre; | ||
|
||
using TensorDistributionTest = T3kMultiDeviceFixture; | ||
|
||
// Builds a row-major TensorSpec with a default memory config for the given
// shape and dtype; shared by all tests below.
TensorSpec get_tensor_spec(const ttnn::SimpleShape& shape, DataType dtype) {
    TensorLayout layout(dtype, Layout::ROW_MAJOR, MemoryConfig{});
    return TensorSpec(shape, layout);
}
|
||
// Replicating a host tensor should place an identical copy on every device.
TEST_F(TensorDistributionTest, Replication) {
    Tensor host_tensor = Tensor::from_vector(
        std::vector<float>{42.F, 13.F, -99.F}, get_tensor_spec(ttnn::SimpleShape{1, 1, 1, 3}, DataType::FLOAT32));

    auto mapper = replicate_tensor_to_mesh_mapper(*mesh_device_);
    Tensor replicated = distribute_tensor(host_tensor, *mesh_device_, *mapper);

    std::vector<Tensor> shards = get_device_tensors(replicated);
    EXPECT_EQ(shards.size(), mesh_device_->num_devices());
    // Every device must see the original data unchanged.
    for (const auto& shard : shards) {
        EXPECT_THAT(shard.to_vector<float>(), ElementsAre(42.F, 13.F, -99.F));
    }
}
|
||
// Sharding along an out-of-range dimension (negative or >= rank) must throw.
TEST_F(TensorDistributionTest, Shard1DInvalidDim) {
    const int device_count = mesh_device_->num_devices();
    Tensor host_tensor = Tensor::from_vector(
        std::vector<float>(device_count, 0),
        get_tensor_spec(ttnn::SimpleShape{1, 1, 1, device_count}, DataType::FLOAT32));

    // Helper: build a mapper for `dim` and attempt distribution.
    auto shard_along = [&](int dim) {
        auto mapper = shard_tensor_to_mesh_mapper(*mesh_device_, dim);
        return distribute_tensor(host_tensor, *mesh_device_, *mapper);
    };

    EXPECT_ANY_THROW(shard_along(-1));
    EXPECT_ANY_THROW(shard_along(4));
}
|
||
// Distributing fewer shards than there are devices should be rejected.
TEST_F(TensorDistributionTest, Shard1DTooFewShards) {
    const int device_count = mesh_device_->num_devices();
    ASSERT_LT(3, device_count);

    const TensorSpec spec = get_tensor_spec(ttnn::SimpleShape{1, 1, 1, 3}, DataType::FLOAT32);
    Tensor host_tensor = Tensor::from_vector(std::vector<float>{42.F, 13.F, -99.F}, spec);

    EXPECT_ANY_THROW({
        auto mapper = shard_tensor_to_mesh_mapper(*mesh_device_, 3);
        Tensor ignored = distribute_tensor(host_tensor, *mesh_device_, *mapper);
    });
}
|
||
// Shards a tensor along dim 1 across all devices, then concatenates the
// per-device shards back along dim 0 and checks round-trip equality.
TEST_F(TensorDistributionTest, Shard1D) {
    const int num_devices = mesh_device_->num_devices();
    std::vector<float> test_data;
    for (int i = 0; i < num_devices; i++) {
        test_data.insert(test_data.end(), {i * 1.F, i * 2.F, i * 3.F});
    }
    Tensor input_tensor =
        Tensor::from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{1, num_devices, 3, 1}, DataType::FLOAT32));

    auto mapper = shard_tensor_to_mesh_mapper(*mesh_device_, 1);
    Tensor sharded_tensor = distribute_tensor(input_tensor, *mesh_device_, *mapper);

    std::vector<Tensor> device_tensors = get_device_tensors(sharded_tensor);
    EXPECT_EQ(device_tensors.size(), mesh_device_->num_devices());
    // Each device should hold the slice corresponding to its position along
    // dim 1. Iterate with size_t to avoid a signed/unsigned comparison.
    for (size_t i = 0; i < device_tensors.size(); i++) {
        const float base = static_cast<float>(i);
        EXPECT_THAT(device_tensors[i].to_vector<float>(), ElementsAre(base * 1.F, base * 2.F, base * 3.F));
    }

    auto composer = concat_mesh_to_tensor_composer(/*dim=*/0);
    Tensor concatenated_tensor = aggregate_tensor(sharded_tensor, *composer);

    Tensor expected_tensor =
        Tensor::from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{num_devices, 1, 3, 1}, DataType::FLOAT32));
    EXPECT_TRUE(ttnn::allclose<float>(concatenated_tensor, expected_tensor));
}
|
||
// Requested mesh shapes that exceed the physical 2x4 mesh must be rejected.
TEST_F(TensorDistributionTest, Shard2DInvalidMeshShape) {
    const auto [num_rows, num_cols] = mesh_device_->shape();
    ASSERT_EQ(num_rows, 2);
    ASSERT_EQ(num_cols, 4);

    const Shard2dConfig config{.row_dim = 1, .col_dim = 2};

    // Too many rows requested.
    EXPECT_ANY_THROW(shard_tensor_to_2d_mesh_mapper(*mesh_device_, MeshShape{3, 1}, config));

    // Too many columns requested.
    EXPECT_ANY_THROW(shard_tensor_to_2d_mesh_mapper(*mesh_device_, MeshShape{2, 5}, config));
}
|
||
// A default-constructed Shard2dConfig (no row_dim / col_dim) is invalid.
TEST_F(TensorDistributionTest, Shard2DInvalidShardConfig) {
    const Shard2dConfig empty_config{};
    EXPECT_ANY_THROW(shard_tensor_to_2d_mesh_mapper(*mesh_device_, MeshShape{2, 4}, empty_config));
}
|
||
// A 2D concat composer with row_dim equal to col_dim is invalid.
TEST_F(TensorDistributionTest, Concat2DInvalidConfig) {
    const Concat2dConfig same_dims{.row_dim = 2, .col_dim = 2};
    EXPECT_ANY_THROW(concat_2d_mesh_to_tensor_composer(*mesh_device_, same_dims));
}
|
||
// Shards along mesh rows only (row_dim = 1, no col_dim); the expectation
// checked below is that each row's shard appears on all 4 devices of that row.
// Fixes: removed stray debug `print()` calls left in the test, dropped the
// unused `num_devices` local, and made the loop index unsigned to avoid a
// signed/unsigned comparison.
TEST_F(TensorDistributionTest, Shard2DReplicateDim) {
    const auto [num_rows, num_cols] = mesh_device_->shape();
    ASSERT_EQ(num_rows, 2);
    ASSERT_EQ(num_cols, 4);

    std::vector<float> test_data = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0};
    Tensor input_tensor =
        Tensor::from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{1, num_rows, num_cols, 1}, DataType::FLOAT32));

    auto mapper = shard_tensor_to_2d_mesh_mapper(
        *mesh_device_,
        MeshShape{num_rows, num_cols},
        Shard2dConfig{
            .row_dim = 1,
        });
    Tensor sharded_tensor = distribute_tensor(input_tensor, *mesh_device_, *mapper);

    std::vector<Tensor> device_tensors = get_device_tensors(sharded_tensor);
    EXPECT_EQ(device_tensors.size(), mesh_device_->num_devices());

    // First mesh row holds the first half of the data; second row the rest.
    size_t i = 0;
    for (; i < 4; i++) {
        EXPECT_THAT(device_tensors[i].to_vector<float>(), ElementsAre(0.0, 1.0, 2.0, 3.0));
    }
    for (; i < device_tensors.size(); i++) {
        EXPECT_THAT(device_tensors[i].to_vector<float>(), ElementsAre(4.0, 5.0, 6.0, 7.0));
    }
}
|
||
// Round-trips a tensor through 2D sharding (mesh rows -> tensor dim 1,
// mesh cols -> tensor dim 2) and 2D concatenation back into one tensor.
// Fix: loop index made unsigned to avoid a signed/unsigned comparison.
TEST_F(TensorDistributionTest, Shard2D) {
    const auto [num_rows, num_cols] = mesh_device_->shape();
    ASSERT_EQ(num_rows, 2);
    ASSERT_EQ(num_cols, 4);
    const int num_devices = num_rows * num_cols;

    std::vector<float> test_data;
    for (int i = 0; i < num_devices; i++) {
        test_data.insert(test_data.end(), {i * 1.F, i * 2.F, i * 3.F});
    }
    Tensor input_tensor =
        Tensor::from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{1, num_rows, num_cols, 3}, DataType::FLOAT32));

    auto mapper = shard_tensor_to_2d_mesh_mapper(
        *mesh_device_,
        MeshShape{num_rows, num_cols},
        Shard2dConfig{
            .row_dim = 1,
            .col_dim = 2,
        });
    Tensor sharded_tensor = distribute_tensor(input_tensor, *mesh_device_, *mapper);

    std::vector<Tensor> device_tensors = get_device_tensors(sharded_tensor);
    EXPECT_EQ(device_tensors.size(), mesh_device_->num_devices());
    // Device i (row-major order over the mesh) holds the i-th 3-element slice.
    for (size_t i = 0; i < device_tensors.size(); i++) {
        const float base = static_cast<float>(i);
        EXPECT_THAT(device_tensors[i].to_vector<float>(), ElementsAre(base * 1.F, base * 2.F, base * 3.F));
    }

    auto composer = concat_2d_mesh_to_tensor_composer(
        *mesh_device_,
        Concat2dConfig{
            .row_dim = 0,
            .col_dim = 2,
        });
    Tensor concatenated_tensor = aggregate_tensor(sharded_tensor, *composer);

    Tensor expected_tensor =
        Tensor::from_vector(test_data, get_tensor_spec(ttnn::SimpleShape{num_rows, 1, num_cols, 3}, DataType::FLOAT32));
    EXPECT_TRUE(ttnn::allclose<float>(concatenated_tensor, expected_tensor));
}
|
||
} // namespace ttnn::distributed::test |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include <gtest/gtest.h> | ||
#include <gmock/gmock.h> | ||
|
||
#include "ttnn/tensor/tensor.hpp" | ||
#include "ttnn/tensor/xtensor/conversion_utils.hpp" | ||
#include "ttnn/tensor/xtensor/partition.hpp" | ||
#include "ttnn/tensor/xtensor/xtensor_all_includes.hpp" | ||
|
||
namespace ttnn { | ||
namespace { | ||
|
||
using ::testing::SizeIs; | ||
using ::tt::tt_metal::Tensor; | ||
using ::ttnn::experimental::xtensor::chunk; | ||
using ::ttnn::experimental::xtensor::concat; | ||
|
||
// Splitting 10 elements into 3 chunks along axis 0 yields sizes 4, 4, 2:
// the last chunk absorbs the remainder.
TEST(PartitionTest, ChunkBasicNonDivisible3) {
    // 1D tensor: [0, 1, ..., 9].
    xt::xarray<float> data = xt::arange<float>(10);

    auto pieces = chunk(data, 3, 0);

    ASSERT_THAT(pieces, SizeIs(3));
    const std::vector<size_t> expected_sizes = {4, 4, 2};
    for (size_t i = 0; i < expected_sizes.size(); ++i) {
        EXPECT_EQ(pieces[i].shape()[0], expected_sizes[i]);
    }
}
|
||
// chunk() may return fewer chunks than requested: splitting 13 elements into
// 6 parts produces five chunks of sizes 3, 3, 3, 3, 1.
TEST(PartitionTest, ChunkBasicLessChunksThanProvided) {
    // 1D tensor: [0, 1, ..., 12].
    xt::xarray<float> data = xt::arange<float>(13);

    auto pieces = chunk(data, 6, 0);

    ASSERT_THAT(pieces, SizeIs(5));
    const std::vector<size_t> expected_sizes = {3, 3, 3, 3, 1};
    for (size_t i = 0; i < expected_sizes.size(); ++i) {
        EXPECT_EQ(pieces[i].shape()[0], expected_sizes[i]);
    }
}
|
||
// concat() with no axis argument should concatenate along axis 0.
// Fix: the return value of xt::allclose was previously discarded, so the
// test asserted nothing; wrap it in EXPECT_TRUE.
TEST(PartitionTest, DefaultAxis) {
    xt::xarray<double> a = {{1.0, 2.0}, {3.0, 4.0}};
    xt::xarray<double> b = {{5.0, 6.0}, {7.0, 8.0}};
    std::vector<xt::xarray<double>> input = {a, b};

    xt::xarray<double> result = concat(input);  // axis=0 by default
    xt::xarray<double> expected = {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}, {7.0, 8.0}};

    EXPECT_TRUE(xt::allclose(result, expected));
}
|
||
// Concatenation along axis 1 of arrays with matching row counts.
// Fix: the return value of xt::allclose was previously discarded, so the
// test asserted nothing; wrap it in EXPECT_TRUE.
TEST(PartitionTest, AxisOne) {
    xt::xarray<int> x = {{1, 2, 3}, {4, 5, 6}};
    xt::xarray<int> y = {{7, 8}, {9, 10}};
    std::vector<xt::xarray<int>> input = {x, y};

    xt::xarray<int> result = concat(input, 1);
    xt::xarray<int> expected = {{1, 2, 3, 7, 8}, {4, 5, 6, 9, 10}};

    EXPECT_TRUE(xt::allclose(result, expected));
}
|
||
// Concatenating three 1D arrays along axis 0.
// Fix: the return value of xt::allclose was previously discarded, so the
// test asserted nothing; wrap it in EXPECT_TRUE.
TEST(PartitionTest, MultipleArraysAxis0) {
    xt::xarray<float> a = {1.0f, 2.0f};
    xt::xarray<float> b = {3.0f, 4.0f};
    xt::xarray<float> c = {5.0f, 6.0f};
    std::vector<xt::xarray<float>> input = {a, b, c};

    xt::xarray<float> result = concat(input, 0);
    xt::xarray<float> expected = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};

    EXPECT_TRUE(xt::allclose(result, expected));
}
|
||
// Concatenating with a default-constructed (empty) array must throw.
TEST(PartitionTest, EmptyArray) {
    xt::xarray<int> filled = {{1, 2}, {3, 4}};
    xt::xarray<int> empty;  // default-constructed, no shape
    std::vector<xt::xarray<int>> input{filled, empty};

    EXPECT_ANY_THROW({ xt::xarray<int> result = concat(input, 0); });
}
|
||
// Concatenates two (2,2,2) arrays along axis 0 and compares against
// xt::concatenate as the reference.
// Fix: the return value of xt::allclose was previously discarded, so the
// test asserted nothing; wrap it in EXPECT_TRUE.
TEST(PartitionTest, HigherDimensions) {
    xt::xarray<int> arr1 = xt::arange<int>(1, 9);  // 1 to 8
    arr1.reshape({2, 2, 2});
    xt::xarray<int> arr2 = xt::arange<int>(9, 17);  // 9 to 16
    arr2.reshape({2, 2, 2});

    std::vector<xt::xarray<int>> input = {arr1, arr2};
    xt::xarray<int> result = concat(input, 0);

    // Expected: shape (4,2,2) with arr1 stacked over arr2 along axis 0.
    xt::xarray<int> expected = xt::concatenate(xt::xtuple(arr1, arr2), 0);

    EXPECT_TRUE(xt::allclose(result, expected));
}
|
||
// Concatenates two (2,2,2) arrays along the innermost axis (2), producing
// shape (2,2,4).
// Fix: the return value of xt::allclose was previously discarded, so the
// test asserted nothing; wrap it in EXPECT_TRUE.
TEST(PartitionTest, HigherAxis) {
    xt::xarray<int> arr1 = {{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}};
    xt::xarray<int> arr2 = {{{9, 10}, {11, 12}}, {{13, 14}, {15, 16}}};
    // Both have shape (2,2,2).

    std::vector<xt::xarray<int>> input = {arr1, arr2};
    xt::xarray<int> result = concat(input, 2);
    xt::xarray<int> expected = {{{1, 2, 9, 10}, {3, 4, 11, 12}}, {{5, 6, 13, 14}, {7, 8, 15, 16}}};

    EXPECT_TRUE(xt::allclose(result, expected));
}
|
||
} // namespace | ||
} // namespace ttnn |
Oops, something went wrong.