Commit: Tests

omilyutin-tt committed Mar 6, 2025
1 parent f307ccc commit 7538501
Showing 3 changed files with 179 additions and 62 deletions.
215 changes: 166 additions & 49 deletions tests/ttnn/unit_tests/gtests/tensor/test_mesh_tensor.cpp
@@ -25,6 +25,7 @@ using ::testing::FloatEq;
using ::testing::Pointwise;

using MeshTensorTest = T3kMultiDeviceFixture;

TEST_F(MeshTensorTest, Lifecycle) {
const TensorSpec tensor_spec =
TensorSpec(ttnn::Shape{1, 1, 32, 32}, TensorLayout(DataType::FLOAT32, Layout::ROW_MAJOR, MemoryConfig{}));
@@ -56,9 +57,7 @@ TEST_F(MeshTensorTest, Lifecycle) {
EXPECT_FALSE(input_tensor.is_allocated());
}

using MeshTensorDeviceTest = T3kMultiDeviceFixture;

TEST_F(MeshTensorDeviceTest, ToHostNonMeshTensor) {
TEST_F(MeshTensorTest, ToHostNonMeshTensor) {
const ttnn::Shape shape{1, 1, 32, 32};
const TensorSpec tensor_spec =
TensorSpec(shape, TensorLayout(DataType::FLOAT32, Layout::ROW_MAJOR, MemoryConfig{}));
@@ -68,7 +67,7 @@ TEST_F(MeshTensorDeviceTest, ToHostNonMeshTensor) {
EXPECT_ANY_THROW(tensor_impl::to_host_mesh_tensor_wrapper(input_host_tensor));
}

TEST_F(MeshTensorDeviceTest, ReplicateHostTensor) {
TEST_F(MeshTensorTest, ReplicateOwnedTensor) {
const ttnn::Shape shape{1, 1, 32, 32};
const TensorSpec tensor_spec =
TensorSpec(shape, TensorLayout(DataType::FLOAT32, Layout::ROW_MAJOR, MemoryConfig{}));
@@ -105,78 +104,196 @@ TEST_F(MeshTensorDeviceTest, ReplicateHostTensor) {
}
}

TEST_F(MeshTensorDeviceTest, WriteMultiDeviceHostTensor) {
struct MeshTensorWriteTestParams {
ttnn::Shape shape;
bool use_pre_allocated_tensor = false;
std::vector<ttnn::Shape> expected_shapes;
std::vector<distributed::MeshCoordinate> expected_coords;
std::function<std::unique_ptr<ttnn::distributed::TensorToMesh>(MeshDevice*)> get_mapper;
};

class MeshTensorWriteTest : public T3kMultiDeviceFixture,
public ::testing::WithParamInterface<MeshTensorWriteTestParams> {};

TEST_P(MeshTensorWriteTest, WriteMultiDeviceHostTensor) {
const int num_devices = mesh_device_->num_devices();
ASSERT_EQ(num_devices, 8);

// Test uneven shard shapes.
const ttnn::Shape shape{1, 9, 32, 32};
const auto expected_shapes_matcher = ElementsAre(
Eq(ttnn::Shape{1, 2, 32, 32}),
Eq(ttnn::Shape{1, 2, 32, 32}),
Eq(ttnn::Shape{1, 2, 32, 32}),
Eq(ttnn::Shape{1, 2, 32, 32}),
Eq(ttnn::Shape{1, 1, 32, 32}));
const ttnn::Shape shape = GetParam().shape;

std::vector<::testing::Matcher<ttnn::Shape>> shape_matchers;
for (const auto& expected_shape : GetParam().expected_shapes) {
shape_matchers.push_back(Eq(expected_shape));
}

std::vector<::testing::Matcher<distributed::MeshCoordinate>> coord_matchers;
for (const auto& expected_coord : GetParam().expected_coords) {
coord_matchers.push_back(Eq(expected_coord));
}

const auto mapper = GetParam().get_mapper(mesh_device_.get());

// Prepare multi-device host tensor to offload on device.
const TensorSpec tensor_spec =
TensorSpec(shape, TensorLayout(DataType::FLOAT32, Layout::ROW_MAJOR, MemoryConfig{}));

std::vector<float> host_data(shape.volume());
std::iota(host_data.begin(), host_data.end(), 0);

// Prepare multi-device host tensor to offload on device.
Tensor input_host_tensor_sharded =
distribute_tensor(Tensor::from_vector(host_data, tensor_spec), *shard_tensor_to_mesh_mapper(*mesh_device_, 1));
Tensor input_host_tensor_sharded = distribute_tensor(Tensor::from_vector(host_data, tensor_spec), *mapper);
EXPECT_TRUE(input_host_tensor_sharded.storage_type() == StorageType::MULTI_DEVICE_HOST);
std::vector<Tensor> input_host_shards = get_device_tensors(input_host_tensor_sharded);

auto* multi_device_host_storage =
std::get_if<tt::tt_metal::MultiDeviceHostStorage>(&input_host_tensor_sharded.get_storage());
ASSERT_NE(multi_device_host_storage, nullptr);
const auto* strategy = std::get_if<tt::tt_metal::ShardTensor>(&multi_device_host_storage->strategy);
ASSERT_NE(strategy, nullptr);
EXPECT_EQ(strategy->shard_dimension, 1);
EXPECT_EQ(multi_device_host_storage->strategy, mapper->config());

auto device_tensor = [&]() {
if (GetParam().use_pre_allocated_tensor) {
Tensor device_tensor =
allocate_tensor_on_mesh(input_host_shards.at(0).get_tensor_spec(), mesh_device_.get());
write_tensor(input_host_tensor_sharded, device_tensor);
return device_tensor;
} else {
return tensor_impl::to_device_mesh_tensor_wrapper(
input_host_tensor_sharded, mesh_device_.get(), MemoryConfig{});
}
}();

// Write host tensor to device.
Tensor device_tensor =
tensor_impl::to_device_mesh_tensor_wrapper(input_host_tensor_sharded, mesh_device_.get(), MemoryConfig{});
EXPECT_TRUE(distributed::is_mesh_buffer_tensor(device_tensor));

auto* device_storage = std::get_if<tt::tt_metal::DeviceStorage>(&device_tensor.get_storage());
ASSERT_NE(device_storage, nullptr);
std::vector<distributed::MeshCoordinate> coords;
std::vector<ttnn::Shape> shapes;
EXPECT_EQ(device_storage->strategy, mapper->config());

std::vector<distributed::MeshCoordinate> device_shard_coords;
std::vector<ttnn::Shape> device_shard_shapes;
for (const auto& [coord, spec] : device_storage->specs) {
coords.push_back(coord);
shapes.push_back(spec.logical_shape());
device_shard_coords.push_back(coord);
device_shard_shapes.push_back(spec.logical_shape());
}
EXPECT_THAT(shapes, expected_shapes_matcher);
EXPECT_THAT(
coords,
ElementsAre(
Eq(distributed::MeshCoordinate{0, 0}),
Eq(distributed::MeshCoordinate{0, 1}),
Eq(distributed::MeshCoordinate{0, 2}),
Eq(distributed::MeshCoordinate{0, 3}),
Eq(distributed::MeshCoordinate{1, 0})));
EXPECT_THAT(device_shard_shapes, ElementsAreArray(shape_matchers));
EXPECT_THAT(device_shard_coords, ElementsAreArray(coord_matchers));

// Read the tensor back, and compare it with input data.
auto host_tensor = tensor_impl::to_host_mesh_tensor_wrapper(device_tensor);
auto* owned_storage = std::get_if<tt::tt_metal::MultiDeviceHostStorage>(&host_tensor.get_storage());
ASSERT_NE(owned_storage, nullptr);
std::vector<ttnn::Shape> host_shapes;
for (const auto& spec : owned_storage->specs) {
host_shapes.push_back(spec.logical_shape());
auto output_host_tensor = tensor_impl::to_host_mesh_tensor_wrapper(device_tensor);
auto* output_multi_device_host_storage =
std::get_if<tt::tt_metal::MultiDeviceHostStorage>(&output_host_tensor.get_storage());
ASSERT_NE(output_multi_device_host_storage, nullptr);
EXPECT_EQ(output_multi_device_host_storage->strategy, mapper->config());
std::vector<ttnn::Shape> output_host_shapes;
for (const auto& spec : output_multi_device_host_storage->specs) {
output_host_shapes.push_back(spec.logical_shape());
}
EXPECT_THAT(host_shapes, expected_shapes_matcher);
EXPECT_THAT(output_host_shapes, ElementsAreArray(shape_matchers));

Tensor aggregated_host_tensor = aggregate_tensor(host_tensor, *concat_mesh_to_tensor_composer(1));
EXPECT_TRUE(aggregated_host_tensor.storage_type() == StorageType::OWNED);
EXPECT_EQ(aggregated_host_tensor.get_tensor_spec().logical_shape(), shape);
std::vector<Tensor> output_host_shards = get_device_tensors(output_host_tensor);
ASSERT_EQ(output_host_shards.size(), input_host_shards.size());
for (int i = 0; i < output_host_shards.size(); i++) {
EXPECT_THAT(
output_host_shards[i].to_vector<float>(), Pointwise(FloatEq(), input_host_shards[i].to_vector<float>()));
}
}

EXPECT_THAT(aggregated_host_tensor.to_vector<float>(), Pointwise(FloatEq(), host_data));
// Returns a vector of `MeshTensorWriteTestParams`; each base case is included twice, with `use_pre_allocated_tensor` set to false and to true.
auto get_mesh_tensor_write_test_params() {
std::vector<MeshTensorWriteTestParams> base_params = {
MeshTensorWriteTestParams{
.shape = ttnn::Shape{1, 8, 32, 32},
.expected_shapes =
{ttnn::Shape{1, 1, 32, 32},
ttnn::Shape{1, 1, 32, 32},
ttnn::Shape{1, 1, 32, 32},
ttnn::Shape{1, 1, 32, 32},
ttnn::Shape{1, 1, 32, 32},
ttnn::Shape{1, 1, 32, 32},
ttnn::Shape{1, 1, 32, 32},
ttnn::Shape{1, 1, 32, 32}},
.expected_coords =
{distributed::MeshCoordinate{0, 0},
distributed::MeshCoordinate{0, 1},
distributed::MeshCoordinate{0, 2},
distributed::MeshCoordinate{0, 3},
distributed::MeshCoordinate{1, 0},
distributed::MeshCoordinate{1, 1},
distributed::MeshCoordinate{1, 2},
distributed::MeshCoordinate{1, 3}},
.get_mapper = [](MeshDevice* device) { return shard_tensor_to_mesh_mapper(*device, 1); },
},
MeshTensorWriteTestParams{
.shape = ttnn::Shape{1, 9, 32, 32},
.expected_shapes =
{ttnn::Shape{1, 2, 32, 32},
ttnn::Shape{1, 2, 32, 32},
ttnn::Shape{1, 2, 32, 32},
ttnn::Shape{1, 2, 32, 32},
ttnn::Shape{1, 1, 32, 32}},
.expected_coords =
{distributed::MeshCoordinate{0, 0},
distributed::MeshCoordinate{0, 1},
distributed::MeshCoordinate{0, 2},
distributed::MeshCoordinate{0, 3},
distributed::MeshCoordinate{1, 0}},
.get_mapper = [](MeshDevice* device) { return shard_tensor_to_mesh_mapper(*device, 1); },
},
MeshTensorWriteTestParams{
.shape = ttnn::Shape{1, 1, 32, 32},
.expected_shapes =
{ttnn::Shape{1, 1, 32, 32},
ttnn::Shape{1, 1, 32, 32},
ttnn::Shape{1, 1, 32, 32},
ttnn::Shape{1, 1, 32, 32},
ttnn::Shape{1, 1, 32, 32},
ttnn::Shape{1, 1, 32, 32},
ttnn::Shape{1, 1, 32, 32},
ttnn::Shape{1, 1, 32, 32}},
.expected_coords =
{distributed::MeshCoordinate{0, 0},
distributed::MeshCoordinate{0, 1},
distributed::MeshCoordinate{0, 2},
distributed::MeshCoordinate{0, 3},
distributed::MeshCoordinate{1, 0},
distributed::MeshCoordinate{1, 1},
distributed::MeshCoordinate{1, 2},
distributed::MeshCoordinate{1, 3}},
.get_mapper = [](MeshDevice* device) { return replicate_tensor_to_mesh_mapper(*device); },
},
MeshTensorWriteTestParams{
.shape = ttnn::Shape{7, 3, 32, 32},
.expected_shapes =
{ttnn::Shape{7, 1, 32, 32},
ttnn::Shape{7, 1, 32, 32},
ttnn::Shape{7, 1, 32, 32},
ttnn::Shape{7, 1, 32, 32},
ttnn::Shape{7, 1, 32, 32},
ttnn::Shape{7, 1, 32, 32}},
.expected_coords =
{distributed::MeshCoordinate{0, 0},
distributed::MeshCoordinate{0, 1},
distributed::MeshCoordinate{0, 2},
distributed::MeshCoordinate{1, 0},
distributed::MeshCoordinate{1, 1},
distributed::MeshCoordinate{1, 2}},
.get_mapper =
[](MeshDevice* device) {
// Distribute over a 2x3 submesh: replicate across mesh rows, shard tensor dim 1 across mesh columns.
return shard_tensor_to_2d_mesh_mapper(*device, MeshShape{2, 3}, Shard2dConfig{std::nullopt, 1});
},
},
};

std::vector<MeshTensorWriteTestParams> params;
for (auto param : base_params) {
param.use_pre_allocated_tensor = false;
params.push_back(param);
param.use_pre_allocated_tensor = true;
params.push_back(param);
}
return params;
}

// TODO: add tests for copying to mesh tensor
INSTANTIATE_TEST_SUITE_P(
MeshTensorWriteTest, MeshTensorWriteTest, ::testing::ValuesIn(get_mesh_tensor_write_test_params()));

} // namespace
} // namespace ttnn::distributed::test
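
For reference, a minimal sketch of the write path the new parameterized test exercises, assuming the ttnn::distributed helpers and includes used in test_mesh_tensor.cpp above and a MeshDevice obtained from a T3k-style fixture. It mirrors the pre-allocated-tensor branch of the test; the helper name is hypothetical.

// Sketch only: shard a host tensor across the mesh, write it into a
// pre-allocated mesh tensor, then read it back shard by shard.
void write_and_readback_sketch(MeshDevice* mesh_device) {
    const ttnn::Shape shape{1, 8, 32, 32};
    const TensorSpec tensor_spec(
        shape, TensorLayout(DataType::FLOAT32, Layout::ROW_MAJOR, MemoryConfig{}));

    std::vector<float> host_data(shape.volume());
    std::iota(host_data.begin(), host_data.end(), 0);

    // Shard along tensor dim 1, one shard per device.
    const auto mapper = shard_tensor_to_mesh_mapper(*mesh_device, 1);
    Tensor host_sharded = distribute_tensor(Tensor::from_vector(host_data, tensor_spec), *mapper);

    // Pre-allocate the destination on the mesh, then write the host shards into it.
    Tensor device_tensor =
        allocate_tensor_on_mesh(get_device_tensors(host_sharded).at(0).get_tensor_spec(), mesh_device);
    write_tensor(host_sharded, device_tensor);

    // Read back and compare shard by shard.
    Tensor readback = tensor_impl::to_host_mesh_tensor_wrapper(device_tensor);
    const auto in_shards = get_device_tensors(host_sharded);
    const auto out_shards = get_device_tensors(readback);
    for (size_t i = 0; i < out_shards.size(); ++i) {
        // Each out_shards[i].to_vector<float>() should match in_shards[i].to_vector<float>().
    }
}
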
24 changes: 12 additions & 12 deletions ttnn/cpp/ttnn/tensor/tensor_impl.cpp
@@ -846,7 +846,7 @@ Tensor to_device_mesh_tensor(
}

template <typename T>
void copy_to_mesh_tensor(const Tensor& host_tensor, const Tensor& mesh_tensor, ttnn::QueueId cq_id) {
void copy_to_mesh_tensor(const Tensor& host_tensor, Tensor& mesh_tensor, ttnn::QueueId cq_id) {
TT_FATAL(host_tensor.storage_type() != StorageType::DEVICE, "Host tensor is on device.");
TT_FATAL(mesh_tensor.storage_type() == StorageType::DEVICE, "Mesh tensor is not on device.");
TT_FATAL(tt::tt_metal::detail::InMainThread(), "copy_to_mesh_tensor must be called from the main thread");
@@ -858,9 +858,6 @@ void copy_to_mesh_tensor(const Tensor& host_tensor, const Tensor& mesh_tensor, t
host_tensor.get_tensor_spec().page_config() == mesh_tensor.get_tensor_spec().page_config(),
"Host tensor has different page config");

// TODO: when copying from multi-device host storage, verify that tensor spec on each piece matches.
// Or should we verify that the tensor simply fits into pre-allocated device tensor?

const auto& tensor_spec = mesh_tensor.get_tensor_spec();
auto mesh_buffer = std::get<DeviceStorage>(mesh_tensor.get_storage()).mesh_buffer;
auto* mesh_device = mesh_buffer->device();
Expand All @@ -878,6 +875,9 @@ void copy_to_mesh_tensor(const Tensor& host_tensor, const Tensor& mesh_tensor, t
},
[](const auto& s) -> DeviceStorage { TT_THROW("Unexpected storage type {}", tt::stl::get_type_name(s)); }},
host_tensor.get_storage());

// Set storage with the populated metadata.
mesh_tensor.set_storage(mesh_storage);
}

template Tensor to_device_mesh_tensor<bfloat16>(
@@ -929,20 +929,20 @@ Tensor to_device_mesh_tensor<bfloat8_b>(
return to_device_mesh_tensor<uint32_t>(tensor, target_device, memory_config, cq_id);
}

template void copy_to_mesh_tensor<bfloat16>(const Tensor& host_tensor, const Tensor& mesh_tensor, ttnn::QueueId cq_id);
template void copy_to_mesh_tensor<float>(const Tensor& host_tensor, const Tensor& mesh_tensor, ttnn::QueueId cq_id);
template void copy_to_mesh_tensor<int32_t>(const Tensor& host_tensor, const Tensor& mesh_tensor, ttnn::QueueId cq_id);
template void copy_to_mesh_tensor<uint32_t>(const Tensor& host_tensor, const Tensor& mesh_tensor, ttnn::QueueId cq_id);
template void copy_to_mesh_tensor<uint16_t>(const Tensor& host_tensor, const Tensor& mesh_tensor, ttnn::QueueId cq_id);
template void copy_to_mesh_tensor<uint8_t>(const Tensor& host_tensor, const Tensor& mesh_tensor, ttnn::QueueId cq_id);
template void copy_to_mesh_tensor<bfloat16>(const Tensor& host_tensor, Tensor& mesh_tensor, ttnn::QueueId cq_id);
template void copy_to_mesh_tensor<float>(const Tensor& host_tensor, Tensor& mesh_tensor, ttnn::QueueId cq_id);
template void copy_to_mesh_tensor<int32_t>(const Tensor& host_tensor, Tensor& mesh_tensor, ttnn::QueueId cq_id);
template void copy_to_mesh_tensor<uint32_t>(const Tensor& host_tensor, Tensor& mesh_tensor, ttnn::QueueId cq_id);
template void copy_to_mesh_tensor<uint16_t>(const Tensor& host_tensor, Tensor& mesh_tensor, ttnn::QueueId cq_id);
template void copy_to_mesh_tensor<uint8_t>(const Tensor& host_tensor, Tensor& mesh_tensor, ttnn::QueueId cq_id);

template <>
void copy_to_mesh_tensor<bfloat4_b>(const Tensor& host_tensor, const Tensor& mesh_tensor, ttnn::QueueId cq_id) {
void copy_to_mesh_tensor<bfloat4_b>(const Tensor& host_tensor, Tensor& mesh_tensor, ttnn::QueueId cq_id) {
copy_to_mesh_tensor<uint32_t>(host_tensor, mesh_tensor, cq_id);
}

template <>
void copy_to_mesh_tensor<bfloat8_b>(const Tensor& host_tensor, const Tensor& mesh_tensor, ttnn::QueueId cq_id) {
void copy_to_mesh_tensor<bfloat8_b>(const Tensor& host_tensor, Tensor& mesh_tensor, ttnn::QueueId cq_id) {
copy_to_mesh_tensor<uint32_t>(host_tensor, mesh_tensor, cq_id);
}

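As a quick caller-side illustration of why copy_to_mesh_tensor now takes mesh_tensor by non-const reference: the final set_storage call above writes the per-shard coordinates and specs back into the destination. A hedged sketch, assuming a MULTI_DEVICE_HOST tensor `host_sharded` and a live `mesh_device` as in the tests above:

// Sketch only: `host_sharded` and `mesh_device` are assumed to be set up as in the tests above.
Tensor device_tensor =
    allocate_tensor_on_mesh(get_device_tensors(host_sharded).at(0).get_tensor_spec(), mesh_device);
// The destination is passed by non-const reference so its DeviceStorage can be updated in place.
tensor_impl::copy_to_mesh_tensor<float>(host_sharded, device_tensor, ttnn::DefaultQueueId);
// After the call, device_tensor's DeviceStorage carries the shard coords/specs from host_sharded.
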
2 changes: 1 addition & 1 deletion ttnn/cpp/ttnn/tensor/tensor_impl.hpp
@@ -213,7 +213,7 @@ Tensor to_device_mesh_tensor(
QueueId cq_id = ttnn::DefaultQueueId);

template <typename T>
void copy_to_mesh_tensor(const Tensor& host_tensor, const Tensor& mesh_tensor, QueueId cq_id = ttnn::DefaultQueueId);
void copy_to_mesh_tensor(const Tensor& host_tensor, Tensor& mesh_tensor, QueueId cq_id = ttnn::DefaultQueueId);

// ======================================================================================
// .to_layout()
