From 1001e959e749425df838c0e8d4a15192f7975907 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Fri, 7 Feb 2025 18:53:04 +0000 Subject: [PATCH 01/76] expose classes to python --- .../ttnn/distributed/distributed_pybind.cpp | 101 ++++++++++++++++++ 1 file changed, 101 insertions(+) diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 9ad24cf4aee..0356eb025ac 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -3,12 +3,15 @@ // SPDX-License-Identifier: Apache-2.0 #include "ttnn/distributed/distributed_pybind.hpp" +#include +#include #include #include #include #include "tt-metalium/mesh_coord.hpp" +#include "distributed_tensor.hpp" #include "ttnn/distributed/api.hpp" #include "ttnn/distributed/types.hpp" #include "ttnn/tensor/tensor.hpp" @@ -25,6 +28,14 @@ namespace ttnn::distributed { namespace py = pybind11; void py_module_types(py::module& module) { + py::class_>(module, "MeshToTensor"); + py::class_>(module, "TensorToMesh"); + py::class_(module, "TensorToMesh"); + py::class_(module, "ShardTensorToMesh"); + py::class_(module, "ShardTensorTo2dMesh"); + py::class_(module, "ConcatMeshToTensor"); + py::class_(module, "Concat2dMeshToTensor"); + py::class_>(module, "MeshDevice"); py::class_(module, "MeshSubDeviceManagerId"); py::class_(module, "MeshShape", "Shape of a mesh device."); @@ -360,6 +371,94 @@ void py_module(py::module& module) { back to all SubDevice IDs. )doc"); + auto py_tensor_to_mesh = static_cast>>(module.attr("TensorToMesh")); + py_tensor_to_mesh + .def(py::init<>(MeshDevice & mesh_device), + py::kw_only(), + py::arg("mesh_device")) + .def("map", &TensorToMesh::map) + .def("config", &TensorToMesh::config); + + auto py_replicate_tensor_to_mesh = static_cast>( + module.attr("ReplicateTensorToMesh")); + py_replicate_tensor_to_mesh + .def(py::init<>(MeshDevice & mesh_device) { + return replicate_tensor_to_mesh_mapper(mesh_device); + }, + py::kw_only(), + py::arg("mesh_device")) + .def(py::init<>() + py::kw_only()) + .def("map",[](self, const Tensor& tensor) { + return self.map(tensor); + }, + py::arg("tensor") + .def("config", &ReplicateTensorToMesh::config); + + auto py_shard_tensor_to_mesh = static_cast>( + module.attr("ShardTensorToMesh")); + py_shard_tensor_to_mesh + .def(py::init<>(MeshDevice & mesh_device, int dim) { + return shard_tensor_to_mesh_mapper(mesh_device, dim); + }, + py::kw_only(), + py::arg("mesh_device"), + py::arg("dim")) + .def(py::init<>() + py::kw_only()) + .def("map",[](self, const Tensor& tensor) { + return self.map(tensor); + }, + py::arg("tensor")) + .def("config", &ShardTensorToMesh::config); + + auto py_shard_tensor_to_2d_mesh = static_cast>(module.attr("ShardTensorTo2dMesh")); + py_shard_tensor_to_2d_mesh + .def(py::init<>(MeshDevice & mesh_device, const MeshShape& mesh_shape, const Shard2dConfig& config) { + return shard_tensor_to_2d_mesh_mapper(mesh_device, mesh_shape, config); + }, + py::kw_only(), + py::arg("mesh_device"), + py::arg("mesh_shape"), + py::arg("config")) + .def(py::init<>() + py::kw_only()) + .def("map",[](self, const Tensor& tensor) { + return self.map(tensor); + }, + py::arg("tensor")) + .def("config", &ShardTensorTo2dMesh::config); + + auto py_mesh_to_tensor = static_cast>>(module.attr("MeshToTensor")); + py_mesh_to_tensor + .def(py::init<>) + .def("compose", &MeshToTensor::compose); + + auto py_concat_mesh_to_tensor = static_cast>(module.attr("ConcatMeshToTensor")); + py_concat_mesh_to_tensor + 
.def(py::init<>(int dim) { + return concat_mesh_to_tensor_composer(dim); + }, + py::kw_only(), + py::arg("dim")) + .def("compose",[](self, const std::vector& tensors) { + return self.compose(tensors); + }, + py::arg("tensors")); + + auto py_concat_2d_mesh_to_tensor = static_cast>(module.attr("Concat2dMeshToTensor")); + py_concat_2d_mesh_to_tensor + .def(py::init<>(MeshDevice & mesh_device, const Concat2dConfig& config) { + return concat_2d_mesh_to_tensor_composer(mesh_device, config); + }, + py::kw_only(), + py::arg("mesh_device"), + py::arg("config")) + .def("compose",[](self, const std::vector& tensors) { + return self.compose(tensors); + }, + .py::arg("tensors")); + module.def( "open_mesh_device", &open_mesh_device, @@ -410,11 +509,13 @@ void py_module(py::module& module) { Tensor: The shard of the tensor corresponding to the device. )doc"); module.def("get_device_tensors", &get_device_tensors, py::arg("tensor"), py::kw_only()); + //TODO: overload this method to enable selection of a subset of shards with a config or something before passing to aggregate module.def( "aggregate_as_tensor", [](const std::vector& tensors) -> Tensor { return aggregate_as_tensor(tensors, AllGatherTensor{}); }, py::arg("tensors"), py::kw_only()); + module.def("get_t3k_physical_device_ids_ring", &get_t3k_physical_device_ids_ring); } From c80d173a3859860dbaf1ddbc51714368dfa433c5 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Sat, 8 Feb 2025 00:23:07 +0000 Subject: [PATCH 02/76] one type error left --- .../ttnn/distributed/distributed_pybind.cpp | 216 +++++++++++------- ttnn/ttnn/distributed/__init__.py | 1 - 2 files changed, 138 insertions(+), 79 deletions(-) diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 0356eb025ac..29efb3f0739 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -3,6 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 #include "ttnn/distributed/distributed_pybind.hpp" +#include #include #include #include @@ -12,8 +13,14 @@ #include #include "tt-metalium/mesh_coord.hpp" #include "distributed_tensor.hpp" +#include "distributed_tensor.cpp" #include "ttnn/distributed/api.hpp" +<<<<<<< HEAD #include "ttnn/distributed/types.hpp" +======= +#include "ttnn/distributed/distributed_tensor_config.hpp" +#include "ttnn/tensor/tensor_utils.hpp" +>>>>>>> one type error left #include "ttnn/tensor/tensor.hpp" #include "ttnn/types.hpp" @@ -27,14 +34,36 @@ namespace ttnn::distributed { namespace py = pybind11; +// Trampoline class to clear virtual method errors +struct ConcreteTensorToMesh : TensorToMesh { + using TensorToMesh::TensorToMesh; // Inherit constructors + + std::vector map(const Tensor& tensor) const override { + PYBIND11_OVERRIDE(std::vector, TensorToMesh, map, tensor); + } + + tt::tt_metal::DistributedTensorConfig config() const override { + PYBIND11_OVERRIDE(tt::tt_metal::DistributedTensorConfig, TensorToMesh, config); + } +}; + +// Trampoline class to clear virtual method errors +struct ConcreteMeshToTensor : MeshToTensor { + Tensor compose(const std::vector& tensors) const override { + PYBIND11_OVERRIDE(Tensor, MeshToTensor, compose, tensors); + } +}; + void py_module_types(py::module& module) { - py::class_>(module, "MeshToTensor"); - py::class_>(module, "TensorToMesh"); - py::class_(module, "TensorToMesh"); - py::class_(module, "ShardTensorToMesh"); - py::class_(module, "ShardTensorTo2dMesh"); - py::class_(module, "ConcatMeshToTensor"); - py::class_(module, 
"Concat2dMeshToTensor"); + py::class_>(module, "MeshToTensor"); + py::class_>(module, "TensorToMesh"); + py::class_>( + module, "ReplicateTensorToMesh"); + py::class_>(module, "ShardTensorToMesh"); + py::class_>(module, "ShardTensorTo2dMesh"); + py::class_>(module, "ConcatMeshToTensor"); + py::class_>( + module, "Concat2dMeshToTensor"); py::class_>(module, "MeshDevice"); py::class_(module, "MeshSubDeviceManagerId"); @@ -371,93 +400,124 @@ void py_module(py::module& module) { back to all SubDevice IDs. )doc"); - auto py_tensor_to_mesh = static_cast>>(module.attr("TensorToMesh")); + auto py_tensor_to_mesh = + static_cast>>(module.attr("TensorToMesh")); py_tensor_to_mesh - .def(py::init<>(MeshDevice & mesh_device), - py::kw_only(), - py::arg("mesh_device")) + .def(py::init([]() -> std::unique_ptr { return std::make_unique(); })) .def("map", &TensorToMesh::map) .def("config", &TensorToMesh::config); - auto py_replicate_tensor_to_mesh = static_cast>( - module.attr("ReplicateTensorToMesh")); + auto py_replicate_tensor_to_mesh = + static_cast>>( + module.attr("ReplicateTensorToMesh")); + py_replicate_tensor_to_mesh - .def(py::init<>(MeshDevice & mesh_device) { - return replicate_tensor_to_mesh_mapper(mesh_device); - }, - py::kw_only(), - py::arg("mesh_device")) - .def(py::init<>() - py::kw_only()) - .def("map",[](self, const Tensor& tensor) { - return self.map(tensor); - }, - py::arg("tensor") + .def( + py::init([](MeshDevice& mesh_device) -> std::unique_ptr { + return ttnn::distributed::replicate_tensor_to_mesh_mapper(mesh_device); + }), + py::kw_only(), + py::arg("mesh_device")) + .def( + py::init([](size_t num_devices) -> std::unique_ptr { + return std::make_unique(ReplicateTensorToMesh(num_devices)); + }), + py::kw_only(), + py::arg("num_devices")) + .def( + "map", + [](const ReplicateTensorToMesh& self, const Tensor& tensor) { return self.map(tensor); }, + py::arg("tensor")) .def("config", &ReplicateTensorToMesh::config); - auto py_shard_tensor_to_mesh = static_cast>( - module.attr("ShardTensorToMesh")); + auto py_shard_tensor_to_mesh = static_cast>>( + module.attr("ShardTensorToMesh")); py_shard_tensor_to_mesh - .def(py::init<>(MeshDevice & mesh_device, int dim) { - return shard_tensor_to_mesh_mapper(mesh_device, dim); - }, - py::kw_only(), - py::arg("mesh_device"), - py::arg("dim")) - .def(py::init<>() - py::kw_only()) - .def("map",[](self, const Tensor& tensor) { - return self.map(tensor); - }, - py::arg("tensor")) + .def( + py::init([](MeshDevice& mesh_device, int dim) -> std::unique_ptr { + return ttnn::distributed::shard_tensor_to_mesh_mapper(mesh_device, dim); + }), + py::kw_only(), + py::arg("mesh_device"), + py::arg("dim")) + .def( + py::init([](size_t num_devices, int dim) -> std::unique_ptr { + return std::make_unique(ShardTensorToMesh(num_devices, dim)); + }), + py::kw_only(), + py::arg("num_devices"), + py::arg("dim")) + .def( + "map", + [](const ShardTensorToMesh& self, const Tensor& tensor) { return self.map(tensor); }, + py::arg("tensor")) .def("config", &ShardTensorToMesh::config); - auto py_shard_tensor_to_2d_mesh = static_cast>(module.attr("ShardTensorTo2dMesh")); + auto py_shard_tensor_to_2d_mesh = + static_cast>>( + module.attr("ShardTensorTo2dMesh")); py_shard_tensor_to_2d_mesh - .def(py::init<>(MeshDevice & mesh_device, const MeshShape& mesh_shape, const Shard2dConfig& config) { - return shard_tensor_to_2d_mesh_mapper(mesh_device, mesh_shape, config); - }, - py::kw_only(), - py::arg("mesh_device"), - py::arg("mesh_shape"), - py::arg("config")) - .def(py::init<>() - 
py::kw_only()) - .def("map",[](self, const Tensor& tensor) { - return self.map(tensor); - }, - py::arg("tensor")) + .def( + py::init( + [](MeshDevice& mesh_device, + const MeshShape& mesh_shape, + const Shard2dConfig& config) -> std::unique_ptr { + return ttnn::distributed::shard_tensor_to_2d_mesh_mapper(mesh_device, mesh_shape, config); + }), + py::kw_only(), + py::arg("mesh_device"), + py::arg("mesh_shape"), + py::arg("config")) + .def( + py::init([](const MeshShape& mesh_shape, const Shard2dConfig& config) -> std::unique_ptr { + return std::make_unique(ShardTensorTo2dMesh(mesh_shape, config)); + }), + py::kw_only(), + py::arg("mesh_shape"), + py::arg("config")) + .def( + "map", + [](const ShardTensorTo2dMesh& self, const Tensor& tensor) { return self.map(tensor); }, + py::arg("tensor")) .def("config", &ShardTensorTo2dMesh::config); - auto py_mesh_to_tensor = static_cast>>(module.attr("MeshToTensor")); + auto py_mesh_to_tensor = + static_cast>>(module.attr("MeshToTensor")); py_mesh_to_tensor - .def(py::init<>) + .def(py::init([]() -> std::unique_ptr { return std::make_unique(); })) .def("compose", &MeshToTensor::compose); - auto py_concat_mesh_to_tensor = static_cast>(module.attr("ConcatMeshToTensor")); + auto py_concat_mesh_to_tensor = static_cast>>( + module.attr("ConcatMeshToTensor")); py_concat_mesh_to_tensor - .def(py::init<>(int dim) { - return concat_mesh_to_tensor_composer(dim); - }, - py::kw_only(), - py::arg("dim")) - .def("compose",[](self, const std::vector& tensors) { - return self.compose(tensors); - }, - py::arg("tensors")); - - auto py_concat_2d_mesh_to_tensor = static_cast>(module.attr("Concat2dMeshToTensor")); + .def( + py::init([](int dim) -> std::unique_ptr { + return ttnn::distributed::concat_mesh_to_tensor_composer(dim); + }), + py::kw_only(), + py::arg("dim")) + .def( + "compose", + [](const ConcatMeshToTensor& self, const std::vector& tensors) { return self.compose(tensors); }, + py::arg("tensors")); + + auto py_concat_2d_mesh_to_tensor = + static_cast>>( + module.attr("Concat2dMeshToTensor")); py_concat_2d_mesh_to_tensor - .def(py::init<>(MeshDevice & mesh_device, const Concat2dConfig& config) { - return concat_2d_mesh_to_tensor_composer(mesh_device, config); - }, - py::kw_only(), - py::arg("mesh_device"), - py::arg("config")) - .def("compose",[](self, const std::vector& tensors) { - return self.compose(tensors); - }, - .py::arg("tensors")); + .def( + py::init([](MeshDevice& mesh_device, const Concat2dConfig& config) -> std::unique_ptr { + return ttnn::distributed::concat_2d_mesh_to_tensor_composer(mesh_device, config); + }), + py::kw_only(), + py::arg("mesh_device"), + py::arg("config")) + .def( + "compose", + [](Concat2dMeshToTensor self, const std::vector& tensors) -> Tensor { + return self.compose(tensors); + }, + py::arg("tensors")); module.def( "open_mesh_device", @@ -509,13 +569,13 @@ void py_module(py::module& module) { Tensor: The shard of the tensor corresponding to the device. 
)doc"); module.def("get_device_tensors", &get_device_tensors, py::arg("tensor"), py::kw_only()); - //TODO: overload this method to enable selection of a subset of shards with a config or something before passing to aggregate + // TODO: overload this method to enable selection of a subset of shards with a config or something before passing to + // aggregate module.def( "aggregate_as_tensor", [](const std::vector& tensors) -> Tensor { return aggregate_as_tensor(tensors, AllGatherTensor{}); }, py::arg("tensors"), py::kw_only()); - module.def("get_t3k_physical_device_ids_ring", &get_t3k_physical_device_ids_ring); } diff --git a/ttnn/ttnn/distributed/__init__.py b/ttnn/ttnn/distributed/__init__.py index bc90ce3cf20..1e566d85567 100644 --- a/ttnn/ttnn/distributed/__init__.py +++ b/ttnn/ttnn/distributed/__init__.py @@ -19,6 +19,5 @@ MeshToTensor, ConcatMeshToTensor, visualize_mesh_device, - ConcatMesh2dToTensor, distribute, ) From 8f59edcf6e9e49f6b34a5b153b77dd9d3dabf733 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Mon, 10 Feb 2025 17:54:11 +0000 Subject: [PATCH 03/76] move class definitions from from distributed_tensor.cpp to.hpp so they can be exposed to the pybind.cpp; add dummy void methods in .cpp to satisfy linker; add new constructors and factory methods to fix type errors --- .../ttnn/distributed/distributed_pybind.cpp | 82 ++++++--- .../ttnn/distributed/distributed_tensor.cpp | 8 +- .../ttnn/distributed/distributed_tensor.hpp | 171 +++++++++++++++++- 3 files changed, 225 insertions(+), 36 deletions(-) diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 29efb3f0739..689af415099 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -13,7 +13,6 @@ #include #include "tt-metalium/mesh_coord.hpp" #include "distributed_tensor.hpp" -#include "distributed_tensor.cpp" #include "ttnn/distributed/api.hpp" <<<<<<< HEAD #include "ttnn/distributed/types.hpp" @@ -400,8 +399,8 @@ void py_module(py::module& module) { back to all SubDevice IDs. 
)doc"); - auto py_tensor_to_mesh = - static_cast>>(module.attr("TensorToMesh")); + auto py_tensor_to_mesh = static_cast>>( + module.attr("TensorToMesh")); py_tensor_to_mesh .def(py::init([]() -> std::unique_ptr { return std::make_unique(); })) .def("map", &TensorToMesh::map) @@ -413,13 +412,13 @@ void py_module(py::module& module) { py_replicate_tensor_to_mesh .def( - py::init([](MeshDevice& mesh_device) -> std::unique_ptr { - return ttnn::distributed::replicate_tensor_to_mesh_mapper(mesh_device); + py::init([](MeshDevice& mesh_device) -> std::unique_ptr { + return std::make_unique(ReplicateTensorToMesh(mesh_device.num_devices())); }), py::kw_only(), py::arg("mesh_device")) .def( - py::init([](size_t num_devices) -> std::unique_ptr { + py::init([](size_t num_devices) -> std::unique_ptr { return std::make_unique(ReplicateTensorToMesh(num_devices)); }), py::kw_only(), @@ -434,14 +433,14 @@ void py_module(py::module& module) { module.attr("ShardTensorToMesh")); py_shard_tensor_to_mesh .def( - py::init([](MeshDevice& mesh_device, int dim) -> std::unique_ptr { - return ttnn::distributed::shard_tensor_to_mesh_mapper(mesh_device, dim); + py::init([](MeshDevice& mesh_device, int dim) -> std::unique_ptr { + return std::make_unique(ShardTensorToMesh(mesh_device, dim)); }), py::kw_only(), py::arg("mesh_device"), py::arg("dim")) .def( - py::init([](size_t num_devices, int dim) -> std::unique_ptr { + py::init([](size_t num_devices, int dim) -> std::unique_ptr { return std::make_unique(ShardTensorToMesh(num_devices, dim)); }), py::kw_only(), @@ -461,17 +460,18 @@ void py_module(py::module& module) { py::init( [](MeshDevice& mesh_device, const MeshShape& mesh_shape, - const Shard2dConfig& config) -> std::unique_ptr { - return ttnn::distributed::shard_tensor_to_2d_mesh_mapper(mesh_device, mesh_shape, config); + const Shard2dConfig& config) -> std::unique_ptr { + return std::make_unique(ShardTensorTo2dMesh(mesh_device, mesh_shape, config)); }), py::kw_only(), py::arg("mesh_device"), py::arg("mesh_shape"), py::arg("config")) .def( - py::init([](const MeshShape& mesh_shape, const Shard2dConfig& config) -> std::unique_ptr { - return std::make_unique(ShardTensorTo2dMesh(mesh_shape, config)); - }), + py::init( + [](const MeshShape& mesh_shape, const Shard2dConfig& config) -> std::unique_ptr { + return std::make_unique(ShardTensorTo2dMesh(mesh_shape, config)); + }), py::kw_only(), py::arg("mesh_shape"), py::arg("config")) @@ -481,8 +481,8 @@ void py_module(py::module& module) { py::arg("tensor")) .def("config", &ShardTensorTo2dMesh::config); - auto py_mesh_to_tensor = - static_cast>>(module.attr("MeshToTensor")); + auto py_mesh_to_tensor = static_cast>>( + module.attr("MeshToTensor")); py_mesh_to_tensor .def(py::init([]() -> std::unique_ptr { return std::make_unique(); })) .def("compose", &MeshToTensor::compose); @@ -491,8 +491,8 @@ void py_module(py::module& module) { module.attr("ConcatMeshToTensor")); py_concat_mesh_to_tensor .def( - py::init([](int dim) -> std::unique_ptr { - return ttnn::distributed::concat_mesh_to_tensor_composer(dim); + py::init([](int dim) -> std::unique_ptr { + return std::make_unique(dim); }), py::kw_only(), py::arg("dim")) @@ -506,9 +506,15 @@ void py_module(py::module& module) { module.attr("Concat2dMeshToTensor")); py_concat_2d_mesh_to_tensor .def( - py::init([](MeshDevice& mesh_device, const Concat2dConfig& config) -> std::unique_ptr { - return ttnn::distributed::concat_2d_mesh_to_tensor_composer(mesh_device, config); - }), + py::init( + [](MeshDevice& mesh_device, const 
Concat2dConfig& config) -> std::unique_ptr { + TT_FATAL( + config.row_dim != config.col_dim, + "Dimensions in 'dims' must be different; got row_dim: {}, col_dim: {}", + config.row_dim, + config.col_dim); + return std::make_unique(mesh_device, config); + }), py::kw_only(), py::arg("mesh_device"), py::arg("config")) @@ -569,6 +575,40 @@ void py_module(py::module& module) { Tensor: The shard of the tensor corresponding to the device. )doc"); module.def("get_device_tensors", &get_device_tensors, py::arg("tensor"), py::kw_only()); + module.def( + "replicate_tensor_to_mesh_mapper", + [](MeshDevice& mesh_device) -> std::unique_ptr { + return replicate_tensor_to_mesh_mapper(mesh_device); + }, + py::arg("mesh_device")); + module.def( + "shard_tensor_to_mesh_mapper", + [](MeshDevice& mesh_device, int dim) -> std::unique_ptr { + return shard_tensor_to_mesh_mapper(mesh_device, dim); + }, + py::arg("mesh_device"), + py::arg("dim")); + module.def( + "shard_tensor_to_2d_mesh_mapper", + [](MeshDevice& mesh_device, + const MeshShape& mesh_shape, + const Shard2dConfig& config) -> std::unique_ptr { + return shard_tensor_to_2d_mesh_mapper(mesh_device, mesh_shape, config); + }, + py::arg("mesh_device"), + py::arg("mesh_shape"), + py::arg("config")); + module.def( + "concat_mesh_to_tensor_composer", + [](int dim) -> std::unique_ptr { return concat_mesh_to_tensor_composer(dim); }, + py::arg("dim")); + module.def( + "concat_2d_mesh_to_tensor_composer", + [](MeshDevice& mesh_device, const Concat2dConfig& config) -> std::unique_ptr { + return concat_2d_mesh_to_tensor_composer(mesh_device, config); + }, + py::arg("mesh_device"), + py::arg("config")); // TODO: overload this method to enable selection of a subset of shards with a config or something before passing to // aggregate module.def( diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp index af3cf6d1fbf..1175acf6422 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp @@ -2,14 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 -#include - -#include "ttnn/distributed/api.hpp" #include "ttnn/distributed/distributed_tensor.hpp" -#include -#include "ttnn/distributed/distributed_tensor_config.hpp" -#include "ttnn/distributed/types.hpp" -#include "ttnn/tensor/xtensor/partition.hpp" +#include "tt-metalium/assert.hpp" namespace ttnn::distributed { namespace { diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp index 7d49ca932f4..c12a51ac3d7 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp @@ -6,6 +6,12 @@ #include "ttnn/tensor/tensor.hpp" #include "ttnn/distributed/types.hpp" +#include "ttnn/distributed/api.hpp" +#include "ttnn/distributed/distributed_tensor_config.hpp" +#include "ttnn/distributed/types.hpp" +#include "ttnn/tensor/xtensor/partition.hpp" +#include +#include namespace ttnn::distributed { @@ -24,6 +30,162 @@ class MeshToTensor { virtual Tensor compose(const std::vector& tensors) const = 0; }; +struct Shard2dConfig { + std::optional row_dim; + std::optional col_dim; +}; + +struct Concat2dConfig { + int row_dim = -1; + int col_dim = -1; +}; + +class ReplicateTensorToMesh : public TensorToMesh { +public: + ReplicateTensorToMesh(size_t num_devices) : num_devices_(num_devices) {} + + ReplicateTensorToMesh(MeshDevice& mesh_device) : num_devices_(mesh_device.num_devices()) {} + + std::vector map(const Tensor& 
tensor) const override { + std::vector tensors; + tensors.reserve(num_devices_); + std::fill_n(std::back_inserter(tensors), num_devices_, tensor); + return tensors; + } + + tt::tt_metal::DistributedTensorConfig config() const override { + return tt::tt_metal::DistributedTensorConfig{ReplicateTensor{num_devices_}}; + } + +private: + size_t num_devices_ = 0; +}; + +class ShardTensorToMesh : public TensorToMesh { +public: + ShardTensorToMesh(size_t num_devices, int dim) : num_devices_(num_devices), shard_dim_(dim) {} + + ShardTensorToMesh(MeshDevice& mesh_device, int dim) : num_devices_(mesh_device.num_devices()), shard_dim_(dim) {} + + std::vector map(const Tensor& tensor) const override { + return experimental::xtensor::chunk(tensor, num_devices_, shard_dim_); + } + + tt::tt_metal::DistributedTensorConfig config() const override { + return tt::tt_metal::DistributedTensorConfig{ShardTensor{shard_dim_}}; + } + +private: + size_t num_devices_ = 0; + int shard_dim_ = -1; +}; + +class ShardTensorTo2dMesh : public TensorToMesh { +public: + ShardTensorTo2dMesh(MeshDevice& mesh_device, const MeshShape& mesh_shape, const Shard2dConfig& config) : + mesh_shape_(mesh_shape), config_(config) { + TT_FATAL( + config.row_dim.has_value() || config.col_dim.has_value(), + "Sharding a tensor to 2D mesh requires at least one dimension to shard"); + TT_FATAL( + mesh_shape.num_rows <= mesh_device.shape().num_rows && // + mesh_shape.num_cols <= mesh_device.shape().num_cols, + "Device mesh shape does not match the provided mesh shape."); + } + + ShardTensorTo2dMesh(const MeshShape& mesh_shape, const Shard2dConfig& config) : + mesh_shape_(mesh_shape), config_(config) {} + + std::vector map(const Tensor& tensor) const override { + const auto [rows, cols] = mesh_shape_; + const auto [row_dim, col_dim] = config_; + + std::vector row_tensors; + + // Shard along rows + if (!row_dim.has_value()) { + row_tensors.reserve(rows); + for (int i = 0; i < rows; ++i) { + row_tensors.push_back(tensor); + } + } else { + row_tensors = experimental::xtensor::chunk(tensor, rows, *row_dim); + } + + std::vector tensor_shards; + tensor_shards.reserve(rows * cols); + // Shard along columns + if (!col_dim.has_value()) { + for (const auto& t : row_tensors) { + for (int i = 0; i < cols; ++i) { + tensor_shards.push_back(t); + } + } + } else { + for (const auto& t : row_tensors) { + auto col_chunks = experimental::xtensor::chunk(t, cols, *col_dim); + tensor_shards.insert(tensor_shards.end(), col_chunks.begin(), col_chunks.end()); + } + } + + TT_FATAL( + static_cast(tensor_shards.size()) == rows * cols, + "ShardTensorTo2dMesh: Sharding failed. Number of shards should match the product of the mesh " + "dimensions. 
Size: {}, rows: {}, cols: {}", + tensor_shards.size(), + rows, + cols); + + return tensor_shards; + } + + tt::tt_metal::DistributedTensorConfig config() const override { + return DistributedTensorConfig{ShardTensor2D{ShardMesh{mesh_shape_.num_rows, mesh_shape_.num_cols}}}; + } + +private: + MeshShape mesh_shape_; + Shard2dConfig config_; +}; + +class ConcatMeshToTensor : public MeshToTensor { +public: + ConcatMeshToTensor(int dim) : concat_dim_(dim) {} + + Tensor compose(const std::vector& tensors) const override { + return experimental::xtensor::concat(tensors, concat_dim_); + } + +private: + int concat_dim_ = -1; +}; + +class Concat2dMeshToTensor : public MeshToTensor { +public: + Concat2dMeshToTensor(MeshDevice& mesh_device, const Concat2dConfig& config) : + mesh_shape_(mesh_device.shape()), config_(config) {} + + Tensor compose(const std::vector& tensors) const override { + const auto [rows, cols] = mesh_shape_; + const auto [row_dim, col_dim] = config_; + + std::vector row_concatenated; + row_concatenated.reserve(rows); + for (int i = 0; i < rows; ++i) { + auto row_start = tensors.begin() + i * cols; + auto row_end = row_start + cols; + std::vector row_tensors(row_start, row_end); + row_concatenated.push_back(experimental::xtensor::concat(row_tensors, col_dim)); + } + + return experimental::xtensor::concat(row_concatenated, row_dim); + } + +private: + MeshShape mesh_shape_; + Concat2dConfig config_; +}; + // Creates a mapper that replicates a tensor across all devices. std::unique_ptr replicate_tensor_to_mesh_mapper(MeshDevice& mesh_device); @@ -32,10 +194,6 @@ std::unique_ptr shard_tensor_to_mesh_mapper(MeshDevice& mesh_devic // Creates a mapper that shards a tensor along two dimensions, which will be intepreted as rows and columns. // If either dimension is not specified, the tensor is replicated along that dimension. -struct Shard2dConfig { - std::optional row_dim; - std::optional col_dim; -}; std::unique_ptr shard_tensor_to_2d_mesh_mapper( MeshDevice& mesh_device, const MeshShape& mesh_shape, const Shard2dConfig& config); @@ -43,10 +201,7 @@ std::unique_ptr shard_tensor_to_2d_mesh_mapper( std::unique_ptr concat_mesh_to_tensor_composer(int dim); // Creates a composer that concatenates a tensor across two dimensions. -struct Concat2dConfig { - int row_dim = -1; - int col_dim = -1; -}; + std::unique_ptr concat_2d_mesh_to_tensor_composer(MeshDevice& mesh_device, const Concat2dConfig& config); // Distributes a host tensor onto multi-device configuration according to the `mapper`. From ff90ba9c9dc5bb045b8677757c2f3aab4e9614de Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Mon, 10 Feb 2025 22:54:24 +0000 Subject: [PATCH 04/76] fix naming errors, add tests, add imports - TODO, fix weird aliasing error with meshdevice vs ttnn.multidevice.meshdevice --- .../distributed/test_distributed_tensor.py | 182 +++++++++ .../ttnn/distributed/distributed_pybind.cpp | 33 +- .../ttnn/distributed/distributed_tensor.cpp | 8 + .../ttnn/distributed/distributed_tensor.hpp | 12 +- ttnn/ttnn/__init__.py | 8 + ttnn/ttnn/distributed/distributed.py | 384 +++++++++--------- 6 files changed, 421 insertions(+), 206 deletions(-) create mode 100644 tests/ttnn/distributed/test_distributed_tensor.py diff --git a/tests/ttnn/distributed/test_distributed_tensor.py b/tests/ttnn/distributed/test_distributed_tensor.py new file mode 100644 index 00000000000..7248bf0cf63 --- /dev/null +++ b/tests/ttnn/distributed/test_distributed_tensor.py @@ -0,0 +1,182 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import torch +import typing +import pytest +import ttnn +import tempfile +from loguru import logger +from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc + +from ttnn import ( + ShardTensorToMesh, + ShardTensor2dMesh, + ReplicateTensorToMesh, + ConcatMeshToTensor, + ConcatMesh2dToTensor, + MeshToTensor, + TensorToMesh, +) +from models.utility_functions import nearest_32 + + +@pytest.mark.parametrize( + "mesh_device", + [ + 32, + ], + indirect=True, +) +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +def test_replicate_to_tensor_mesh(mesh_device, dtype): + torch.manual_seed(1234) + + torch_tensor = torch.randn(1, 1, 32, 8192) + replicated = ttnn.from_torch( + torch_tensor, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + device=mesh_device, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + ) + + out_tensors = ttnn.get_device_tensors(mesh_device) + + test = ttnn.from_torch(torch_tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device) + + out_pass, out_pcc = comp_pcc(out_tensors[0], test, pcc=0.99) + logger.info(f"PCC value: {out_pcc}") + assert out_pass + + +def test_shard_to_tensor_mesh(mesh_device, dtype): + torch.manual_seed(1234) + + torch_tensor = torch.randn(1, 1, 8192, 32768) + tensor_shards = ttnn.from_torch( + torch_tensor, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + device=mesh_device, + mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=3), + ) + + test = ttnn.from_torch(torch_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=mesh_device) + + out_tensors = ttnn.get_device_tensors(tensor_shards) + + out_tensor = ttnn.aggregate_as_tensor(out_tensors) + + out_pass, out_pcc = comp_pcc(out_tensor, test, pcc=0.99) + logger.info(f"PCC value: {out_pcc}") + assert out_pass + + +def test_concat_to_tensor(mesh_device, dtype): + torch.manual_seed(1234) + + torch_tensor = torch.randn(1, 1, 8192, 32768) + sharded = ttnn.from_torch( + torch_tensor, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + device=mesh_device, + mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=3), + ) + + test = ttnn.from_torch(torch_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=mesh_device) + + out_tensor = ttnn.to_torch( + torch_tensor, dtype=ttnn.bfloat16, mesh_composer=ttnn.ConcatMeshToTensor(dim=3), device=mesh_device + ) + + out_pass, out_pcc = comp_pcc(out_tensor, test, pcc=0.99) + logger.info(f"PCC value: {out_pcc}") + assert out_pass + + +@pytest.mark.parametrize( + "mesh_shape, mesh_device", [pytest.param((8, 4), (8, 4), id="8x4_grid")], indirect=["mesh_device"] +) +@pytest.mark.parametrize( + "M, K, N", + [pytest.param(32, 8192, 28 * 1024), pytest.param(32, 28 * 1024, 8192)], +) +def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): + torch.manual_seed(1234) + + torch_tensor = torch.randn(1, 1, M, K) + core_grid = ttnn.CoreGrid(y=1, x=8) + + # If K < N it's FF1-like test case, else FF2-like test case + shard_dim = (None, 3) if K < N else (3, None) # None means to replicate along this dim + + K = K // mesh_shape[1] if K < N else K // mesh_shape[0] + N = N // mesh_shape[0] if K < N else N // mesh_shape[1] + + sharded_mem_config = ttnn.create_sharded_memory_config( + shape=(M // core_grid.y, K // core_grid.x), + core_grid=core_grid, + strategy=ttnn.ShardStrategy.WIDTH, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + use_height_and_width_as_shard_shape=True, + ) + + tensor_shards = ttnn.from_torch( + torch_tensor, + 
dtype=dtype, + layout=ttnn.TILE_LAYOUT, + memory_config=sharded_mem_config if M == 32 else ttnn.DRAM_MEMORY_CONFIG, + device=mesh_device, + mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim), + ) + + out_tensors = ttnn.get_device_tensors(tensor_shards) + + for tensor in out_tensors: + print(tensor) + + # out_pass, out_pcc = comp_pcc(tensor_shards, out, pcc=0.99) + # logger.info(f"PCC value: {out_pcc}") + # assert out_pass + + +def test_concat2d_to_tensor(M, K, N, dtype, mesh_shape, mesh_device): + torch.manual_seed(1234) + + torch_tensor = torch.randn(1, 1, M, K) + core_grid = ttnn.CoreGrid(y=1, x=8) + + # If K < N it's FF1-like test case, else FF2-like test case + shard_dim = (None, 3) if K < N else (3, None) # None means to replicate along this dim + concat_dim = (3, 1) if K < N else (1, 3) # dim 1 for reduce, dim 3 for concatenating fractures + + K = K // mesh_shape[1] if K < N else K // mesh_shape[0] + N = N // mesh_shape[0] if K < N else N // mesh_shape[1] + + sharded_mem_config = ttnn.create_sharded_memory_config( + shape=(M // core_grid.y, K // core_grid.x), + core_grid=core_grid, + strategy=ttnn.ShardStrategy.WIDTH, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + use_height_and_width_as_shard_shape=True, + ) + + tensor_shards = ttnn.from_torch( + torch_tensor, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + memory_config=sharded_mem_config if M == 32 else ttnn.DRAM_MEMORY_CONFIG, + device=mesh_device, + mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim), + ) + + out = ttnn.to_torch( + tensor_shards, mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, dims=concat_dim, mesh_shape=mesh_shape) + ) + + out_pass, out_pcc = comp_pcc(out, torch_tensor, pcc=0.99) + logger.info(f"PCC value: {out_pcc}") + assert out_pass diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 689af415099..42bf7a7b5e6 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -59,10 +59,10 @@ void py_module_types(py::module& module) { py::class_>( module, "ReplicateTensorToMesh"); py::class_>(module, "ShardTensorToMesh"); - py::class_>(module, "ShardTensorTo2dMesh"); + py::class_>(module, "ShardTensor2dMesh"); py::class_>(module, "ConcatMeshToTensor"); - py::class_>( - module, "Concat2dMeshToTensor"); + py::class_>( + module, "ConcatMesh2dToTensor"); py::class_>(module, "MeshDevice"); py::class_(module, "MeshSubDeviceManagerId"); @@ -452,16 +452,15 @@ void py_module(py::module& module) { py::arg("tensor")) .def("config", &ShardTensorToMesh::config); - auto py_shard_tensor_to_2d_mesh = - static_cast>>( - module.attr("ShardTensorTo2dMesh")); + auto py_shard_tensor_to_2d_mesh = static_cast>>( + module.attr("ShardTensor2dMesh")); py_shard_tensor_to_2d_mesh .def( py::init( [](MeshDevice& mesh_device, const MeshShape& mesh_shape, - const Shard2dConfig& config) -> std::unique_ptr { - return std::make_unique(ShardTensorTo2dMesh(mesh_device, mesh_shape, config)); + const Shard2dConfig& config) -> std::unique_ptr { + return std::make_unique(ShardTensor2dMesh(mesh_device, mesh_shape, config)); }), py::kw_only(), py::arg("mesh_device"), @@ -469,17 +468,17 @@ void py_module(py::module& module) { py::arg("config")) .def( py::init( - [](const MeshShape& mesh_shape, const Shard2dConfig& config) -> std::unique_ptr { - return std::make_unique(ShardTensorTo2dMesh(mesh_shape, config)); + [](const MeshShape& mesh_shape, const Shard2dConfig& 
config) -> std::unique_ptr { + return std::make_unique(ShardTensor2dMesh(mesh_shape, config)); }), py::kw_only(), py::arg("mesh_shape"), py::arg("config")) .def( "map", - [](const ShardTensorTo2dMesh& self, const Tensor& tensor) { return self.map(tensor); }, + [](const ShardTensor2dMesh& self, const Tensor& tensor) { return self.map(tensor); }, py::arg("tensor")) - .def("config", &ShardTensorTo2dMesh::config); + .def("config", &ShardTensor2dMesh::config); auto py_mesh_to_tensor = static_cast>>( module.attr("MeshToTensor")); @@ -502,25 +501,25 @@ void py_module(py::module& module) { py::arg("tensors")); auto py_concat_2d_mesh_to_tensor = - static_cast>>( - module.attr("Concat2dMeshToTensor")); + static_cast>>( + module.attr("ConcatMesh2dToTensor")); py_concat_2d_mesh_to_tensor .def( py::init( - [](MeshDevice& mesh_device, const Concat2dConfig& config) -> std::unique_ptr { + [](MeshDevice& mesh_device, const Concat2dConfig& config) -> std::unique_ptr { TT_FATAL( config.row_dim != config.col_dim, "Dimensions in 'dims' must be different; got row_dim: {}, col_dim: {}", config.row_dim, config.col_dim); - return std::make_unique(mesh_device, config); + return std::make_unique(mesh_device, config); }), py::kw_only(), py::arg("mesh_device"), py::arg("config")) .def( "compose", - [](Concat2dMeshToTensor self, const std::vector& tensors) -> Tensor { + [](ConcatMesh2dToTensor self, const std::vector& tensors) -> Tensor { return self.compose(tensors); }, py::arg("tensors")); diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp index 1175acf6422..897231fc8e6 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp @@ -161,7 +161,11 @@ std::unique_ptr shard_tensor_to_2d_mesh_mapper( mesh_shape[0] <= mesh_device.shape()[0] && // mesh_shape[1] <= mesh_device.shape()[1], "Device mesh shape does not match the provided mesh shape."); +<<<<<<< HEAD return std::make_unique(mesh_shape[0], mesh_shape[1], config); +======= + return std::make_unique(mesh_shape, config); +>>>>>>> fix naming errors, add tests, add imports - TODO, fix weird aliasing error with meshdevice vs ttnn.multidevice.meshdevice } std::unique_ptr concat_mesh_to_tensor_composer(int dim) { @@ -174,8 +178,12 @@ std::unique_ptr concat_2d_mesh_to_tensor_composer(MeshDevice& mesh "Dimensions in 'dims' must be different; got row_dim: {}, col_dim: {}", config.row_dim, config.col_dim); +<<<<<<< HEAD TT_FATAL(mesh_device.shape().dims() == 2, "Mesh device is not configured as a 2D mesh: {}", mesh_device.shape()); return std::make_unique(mesh_device.shape()[0], mesh_device.shape()[1], config); +======= + return std::make_unique(mesh_device, config); +>>>>>>> fix naming errors, add tests, add imports - TODO, fix weird aliasing error with meshdevice vs ttnn.multidevice.meshdevice } Tensor distribute_tensor( diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp index c12a51ac3d7..7d45355c638 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp @@ -80,9 +80,9 @@ class ShardTensorToMesh : public TensorToMesh { int shard_dim_ = -1; }; -class ShardTensorTo2dMesh : public TensorToMesh { +class ShardTensor2dMesh : public TensorToMesh { public: - ShardTensorTo2dMesh(MeshDevice& mesh_device, const MeshShape& mesh_shape, const Shard2dConfig& config) : + ShardTensor2dMesh(MeshDevice& mesh_device, const MeshShape& mesh_shape, const 
Shard2dConfig& config) : mesh_shape_(mesh_shape), config_(config) { TT_FATAL( config.row_dim.has_value() || config.col_dim.has_value(), @@ -93,7 +93,7 @@ class ShardTensorTo2dMesh : public TensorToMesh { "Device mesh shape does not match the provided mesh shape."); } - ShardTensorTo2dMesh(const MeshShape& mesh_shape, const Shard2dConfig& config) : + ShardTensor2dMesh(const MeshShape& mesh_shape, const Shard2dConfig& config) : mesh_shape_(mesh_shape), config_(config) {} std::vector map(const Tensor& tensor) const override { @@ -130,7 +130,7 @@ class ShardTensorTo2dMesh : public TensorToMesh { TT_FATAL( static_cast(tensor_shards.size()) == rows * cols, - "ShardTensorTo2dMesh: Sharding failed. Number of shards should match the product of the mesh " + "ShardTensor2dMesh: Sharding failed. Number of shards should match the product of the mesh " "dimensions. Size: {}, rows: {}, cols: {}", tensor_shards.size(), rows, @@ -160,9 +160,9 @@ class ConcatMeshToTensor : public MeshToTensor { int concat_dim_ = -1; }; -class Concat2dMeshToTensor : public MeshToTensor { +class ConcatMesh2dToTensor : public MeshToTensor { public: - Concat2dMeshToTensor(MeshDevice& mesh_device, const Concat2dConfig& config) : + ConcatMesh2dToTensor(MeshDevice& mesh_device, const Concat2dConfig& config) : mesh_shape_(mesh_device.shape()), config_(config) {} Tensor compose(const std::vector& tensors) const override { diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index 838a8cddd3d..87e58785cc2 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -95,6 +95,14 @@ def manage_config(name, value): from ttnn._ttnn.multi_device import ( + MeshDevice, + MeshToTensor, + TensorToMesh, + ReplicateTensorToMesh, + ShardTensorToMesh, + ShardTensor2dMesh, + ConcatMeshToTensor, + ConcatMesh2dToTensor, get_device_tensor, get_device_tensors, aggregate_as_tensor, diff --git a/ttnn/ttnn/distributed/distributed.py b/ttnn/ttnn/distributed/distributed.py index db7e5f860e7..d59377181df 100644 --- a/ttnn/ttnn/distributed/distributed.py +++ b/ttnn/ttnn/distributed/distributed.py @@ -189,242 +189,260 @@ def create_mesh_device(*args, **kwargs): yield mesh_device finally: close_mesh_device(mesh_device) - - -class TensorToMesh: + +def synchronize_devices( + devices: Union["ttnn.Device", "ttnn.MeshDevice"], + queue_id: Optional[int] = ttnn.DefaultQueueId, + sub_device_ids: List[ttnn.SubDeviceId] = [], +) -> None: """ - Defines the mapping of a torch.Tensor to a device mesh: e.g. Shard/Replicate. - You can also "Bring your own TensorToMesh" based on your custom mapping. - """ - - def __init__(self, mesh_device): - self.mesh_device = mesh_device + synchronize_devices(devices: Union[ttnn.Device, ttnn.MeshDevice], queue_id: Optional[int] = None, sub_device_ids: List[ttnn.SubDeviceId] = []) -> None: - def map(self, tensor: "torch.Tensor"): - raise NotImplementedError("Subclasses must implement this method") - - def config(self): - raise NotImplementedError("Subclasses must implement this method") - - -class MeshToTensor: + Synchronize the devices with host by waiting for all operations to complete. + If queue_id is provided then only the operations associated with that queue_id are waited for, + otherwise operations for all command queues are waited on. """ - Defines the inverse operation of TensorToMesh. Given a set of per-device - ttnn.Tensor objects (aggregated into a single ttnn.Tensor), this class defines - the mapping back to one or many torch.Tensor objects. 
+ if isinstance(devices, ttnn.Device): + ttnn._ttnn.device.synchronize_device(devices, queue_id, sub_device_ids) + else: + for device in devices.get_device_ids(): + ttnn._ttnn.device.synchronize_device(devices.get_device(device), queue_id, sub_device_ids) - You can also "Bring your own MeshToTensor" based on your custom mapping. - """ - def compose(self, tensor: ttnn.Tensor): - raise NotImplementedError("Subclasses must implement this method") +# class TensorToMesh: +# """ +# Defines the mapping of a torch.Tensor to a device mesh: e.g. Shard/Replicate. +# You can also "Bring your own TensorToMesh" based on your custom mapping. +# """ +# def __init__(self, mesh_device): +# self.mesh_device = mesh_device -class ShardTensorToMesh(TensorToMesh): - def __init__(self, mesh_device, dim): - super().__init__(mesh_device) - self.shard_dim = dim +# def map(self, tensor: "torch.Tensor"): +# raise NotImplementedError("Subclasses must implement this method") - def map(self, tensor: "torch.Tensor") -> Dict[int, ttnn.Tensor]: - import torch +# def config(self): +# raise NotImplementedError("Subclasses must implement this method") - sliced_tensors = torch.chunk(tensor, self.mesh_device.get_num_devices(), dim=self.shard_dim) - return list(sliced_tensors) - def config(self): - return { - "strategy": "shard", - "shard_dim": f"{self.shard_dim}", - } +# class MeshToTensor: +# """ +# Defines the inverse operation of TensorToMesh. Given a set of per-device +# ttnn.Tensor objects (aggregated into a single ttnn.Tensor), this class defines +# the mapping back to one or many torch.Tensor objects. +# You can also "Bring your own MeshToTensor" based on your custom mapping. +# """ -class ShardTensor2dMesh(TensorToMesh): - """ - Shard a tensor across a 2D mesh of devices. - - This class implements a strategy for distributing a tensor across a 2D grid of devices, - allowing for efficient parallel processing in distributed computing environments. - """ - - def __init__(self, mesh_device: MeshDevice, mesh_shape: Tuple[int, int], dims: Tuple[Optional[int], Optional[int]]): - """ - Initialize the ShardTensor2dMesh. +# def compose(self, tensor: ttnn.Tensor): +# raise NotImplementedError("Subclasses must implement this method") - Args: - mesh_device: The target device mesh for distributing the tensor. - mesh_shape: The shape of the 2D mesh as (rows, cols). - dims: The dimensions to shard along, specified as (row_dim, col_dim). - The `dims` tuple determines how the tensor is sharded across the 2D mesh: - - row_dim: The dimension to shard across mesh rows (or None for replication). - - col_dim: The dimension to shard across mesh columns (or None for replication). +# class ShardTensorToMesh(TensorToMesh): +# def __init__(self, mesh_device, dim): +# super().__init__(mesh_device) +# self.shard_dim = dim - Examples: - 1. dims=(2, 3) for a tensor of shape (A, B, C, D): - - Shard along dimension 2 (C) across mesh rows - - Shard along dimension 3 (D) across mesh columns +# def map(self, tensor: "torch.Tensor") -> Dict[int, ttnn.Tensor]: +# import torch - 2. dims=(None, 3): - - Replicate across mesh rows - - Shard along dimension 3 (D) across mesh columns +# sliced_tensors = torch.chunk(tensor, self.mesh_device.get_num_devices(), dim=self.shard_dim) +# return list(sliced_tensors) - 3. 
dims=(None, None): - - Fully replicate the tensor across all devices - """ - super().__init__(mesh_device) - self.mesh_shape: Tuple[int, int] = mesh_shape - self.dims: Tuple[Optional[int], Optional[int]] = dims +# def config(self): +# return { +# "strategy": "shard", +# "shard_dim": f"{self.shard_dim}", +# } - mesh_device_rows, mesh_device_cols = self.mesh_device.shape - if mesh_shape[0] > mesh_device_rows or mesh_shape[1] > mesh_device_cols: - raise ValueError("ShardTensor2dMesh: Device mesh shape does not match the provided mesh shape.") - def map(self, tensor: "torch.Tensor") -> List["torch.Tensor"]: - """ - Map the input tensor to a list of sharded tensors. +# class ShardTensor2dMesh(TensorToMesh): +# """ +# Shard a tensor across a 2D mesh of devices. - Args: - tensor: The input tensor to be sharded. +# This class implements a strategy for distributing a tensor across a 2D grid of devices, +# allowing for efficient parallel processing in distributed computing environments. +# """ - Returns: - A list of sharded tensors, one for each device in the mesh. +# def __init__(self, mesh_device: MeshDevice, mesh_shape: Tuple[int, int], dims: Tuple[Optional[int], Optional[int]]): +# """ +# Initialize the ShardTensor2dMesh. - Raises: - ValueError: If the number of sharding dimensions is not 2. - """ - import torch +# Args: +# mesh_device: The target device mesh for distributing the tensor. +# mesh_shape: The shape of the 2D mesh as (rows, cols). +# dims: The dimensions to shard along, specified as (row_dim, col_dim). - if len(self.dims) != 2: - raise ValueError("ShardTensor2dMesh only supports 2D shard dimensions") +# The `dims` tuple determines how the tensor is sharded across the 2D mesh: +# - row_dim: The dimension to shard across mesh rows (or None for replication). +# - col_dim: The dimension to shard across mesh columns (or None for replication). + +# Examples: +# 1. dims=(2, 3) for a tensor of shape (A, B, C, D): +# - Shard along dimension 2 (C) across mesh rows +# - Shard along dimension 3 (D) across mesh columns + +# 2. dims=(None, 3): +# - Replicate across mesh rows +# - Shard along dimension 3 (D) across mesh columns + +# 3. dims=(None, None): +# - Fully replicate the tensor across all devices +# """ +# super().__init__(mesh_device) +# self.mesh_shape: Tuple[int, int] = mesh_shape +# self.dims: Tuple[Optional[int], Optional[int]] = dims + +# mesh_device_rows, mesh_device_cols = self.mesh_device.shape +# if mesh_shape[0] > mesh_device_rows or mesh_shape[1] > mesh_device_cols: +# raise ValueError("ShardTensor2dMesh: Device mesh shape does not match the provided mesh shape.") + +# def map(self, tensor: "torch.Tensor") -> List["torch.Tensor"]: +# """ +# Map the input tensor to a list of sharded tensors. + +# Args: +# tensor: The input tensor to be sharded. + +# Returns: +# A list of sharded tensors, one for each device in the mesh. - rows, cols = self.mesh_shape - row_dim, col_dim = self.dims +# Raises: +# ValueError: If the number of sharding dimensions is not 2. 
+# """ +# import torch - # Shard along rows - row_tensors = ( - [tensor.clone() for _ in range(rows)] if row_dim is None else torch.chunk(tensor, rows, dim=row_dim) - ) +# if len(self.dims) != 2: +# raise ValueError("ShardTensor2dMesh only supports 2D shard dimensions") - # Shard along columns - if col_dim is None: - return [t.clone() for t in row_tensors for _ in range(cols)] - tensor_shards = [tt for t in row_tensors for tt in torch.chunk(t, cols, dim=col_dim)] +# rows, cols = self.mesh_shape +# row_dim, col_dim = self.dims - if len(tensor_shards) != rows * cols: - raise ValueError( - f"ShardTensor2dMesh: Sharding failed. Number of shards should match the product of the mesh dimensions. Got {len(tensor_shards)} shards but expected {rows * cols} ({rows} rows * {cols} cols)." - ) +# # Shard along rows +# row_tensors = ( +# [tensor.clone() for _ in range(rows)] if row_dim is None else torch.chunk(tensor, rows, dim=row_dim) +# ) - return tensor_shards +# # Shard along columns +# if col_dim is None: +# return [t.clone() for t in row_tensors for _ in range(cols)] +# tensor_shards = [tt for t in row_tensors for tt in torch.chunk(t, cols, dim=col_dim)] - def config(self) -> Dict[str, str]: - """ - Provide the configuration of the sharding strategy. +# if len(tensor_shards) != rows * cols: +# raise ValueError( +# f"ShardTensor2dMesh: Sharding failed. Number of shards should match the product of the mesh dimensions. Got {len(tensor_shards)} shards but expected {rows * cols} ({rows} rows * {cols} cols)." +# ) - Returns: - A dictionary containing the sharding strategy and dimensions. - """ - return { - "strategy": "shard_2d", - "mesh_shape_y": str(self.mesh_shape[0]), - "mesh_shape_x": str(self.mesh_shape[1]), - } +# return tensor_shards +# def config(self) -> Dict[str, str]: +# """ +# Provide the configuration of the sharding strategy. + +# Returns: +# A dictionary containing the sharding strategy and dimensions. +# """ +# return { +# "strategy": "shard_2d", +# "mesh_shape_y": str(self.mesh_shape[0]), +# "mesh_shape_x": str(self.mesh_shape[1]), +# } -class ConcatMesh2dToTensor(MeshToTensor): - """ - Concatenate tensors from a 2D mesh back into a single tensor. - - This class implements the inverse operation of ShardTensor2dMesh, combining - sharded tensors from a 2D device mesh back into a single tensor. - """ - def __init__(self, mesh_device: MeshDevice, mesh_shape: Tuple[int, int], dims: Tuple[int, int]): - """ - Initialize the ConcatMesh2dToTensor. +# class ConcatMesh2dToTensor(MeshToTensor): +# """ +# Concatenate tensors from a 2D mesh back into a single tensor. - Args: - mesh_device: The source device mesh containing the sharded tensors. - mesh_shape: The shape of the 2D mesh as (rows, cols). - dims: A tuple of two integers specifying the dimensions along which to concatenate the tensors. - The first element (row_dim) indicates the dimension for concatenating tensors from different rows. - The second element (col_dim) indicates the dimension for concatenating tensors from different columns. - Both dimensions must be specified and different from each other. - These dimensions correspond to the tensor dimensions, not the mesh dimensions. - For example, if the original tensor was 4D with shape (batch, channel, height, width), - and it was sharded across height and width, dims might be (-2, -1) or (2, 3). +# This class implements the inverse operation of ShardTensor2dMesh, combining +# sharded tensors from a 2D device mesh back into a single tensor. 
+# """ + +# def __init__(self, mesh_device: MeshDevice, mesh_shape: Tuple[int, int], dims: Tuple[int, int]): +# """ +# Initialize the ConcatMesh2dToTensor. + +# Args: +# mesh_device: The source device mesh containing the sharded tensors. +# mesh_shape: The shape of the 2D mesh as (rows, cols). +# dims: A tuple of two integers specifying the dimensions along which to concatenate the tensors. +# The first element (row_dim) indicates the dimension for concatenating tensors from different rows. +# The second element (col_dim) indicates the dimension for concatenating tensors from different columns. +# Both dimensions must be specified and different from each other. +# These dimensions correspond to the tensor dimensions, not the mesh dimensions. +# For example, if the original tensor was 4D with shape (batch, channel, height, width), +# and it was sharded across height and width, dims might be (-2, -1) or (2, 3). - Raises: - ValueError: If either dimension in 'dims' is None or if both dimensions are the same. - """ - self.mesh_device = mesh_device - self.mesh_shape = mesh_shape - self.dims = dims - if self.dims[0] == self.dims[1]: - raise ValueError("Both dimensions in 'dims' must be different") +# Raises: +# ValueError: If either dimension in 'dims' is None or if both dimensions are the same. +# """ +# self.mesh_device = mesh_device +# self.mesh_shape = mesh_shape +# self.dims = dims +# if self.dims[0] == self.dims[1]: +# raise ValueError("Both dimensions in 'dims' must be different") - def compose(self, tensor: ttnn.Tensor) -> "torch.Tensor": - """ - Compose the sharded tensors back into a single tensor. +# def compose(self, tensor: ttnn.Tensor) -> "torch.Tensor": +# """ +# Compose the sharded tensors back into a single tensor. - Args: - tensor: A ttnn.Tensor object containing the sharded tensors distributed across multiple devices. +# Args: +# tensor: A ttnn.Tensor object containing the sharded tensors distributed across multiple devices. - Returns: - A single torch.Tensor that combines all the sharded tensors from all devices. +# Returns: +# A single torch.Tensor that combines all the sharded tensors from all devices. - This method first concatenates the shards along the column dimension within each row, - then concatenates the resulting tensors along the row dimension to form the final tensor. - """ - import torch +# This method first concatenates the shards along the column dimension within each row, +# then concatenates the resulting tensors along the row dimension to form the final tensor. 
+# """ +# import torch - device_shards = [ - ttnn.to_torch(tt_input_tensor, mesh_composer=None) for tt_input_tensor in ttnn.get_device_tensors(tensor) - ] +# device_shards = [ +# ttnn.to_torch(tt_input_tensor, mesh_composer=None) for tt_input_tensor in ttnn.get_device_tensors(tensor) +# ] - rows, cols = self.mesh_shape - row_dim, col_dim = self.dims +# rows, cols = self.mesh_shape +# row_dim, col_dim = self.dims - # Reshape the list of shards into a 2D list representing the device mesh - mesh_shape = [device_shards[i : i + cols] for i in range(0, len(device_shards), cols)] +# # Reshape the list of shards into a 2D list representing the device mesh +# mesh_shape = [device_shards[i : i + cols] for i in range(0, len(device_shards), cols)] - # Concatenate along columns first (within each row) - row_concatenated = [torch.cat(row, dim=col_dim) for row in mesh_shape] +# # Concatenate along columns first (within each row) +# row_concatenated = [torch.cat(row, dim=col_dim) for row in mesh_shape] - # Then concatenate the resulting tensors along rows - return torch.cat(row_concatenated, dim=row_dim) +# # Then concatenate the resulting tensors along rows +# return torch.cat(row_concatenated, dim=row_dim) -class ReplicateTensorToMesh(TensorToMesh): - def __init__(self, mesh_device: MeshDevice): - super().__init__(mesh_device) +# class ReplicateTensorToMesh(TensorToMesh): +# def __init__(self, mesh_device: MeshDevice): +# super().__init__(mesh_device) - def map(self, tensor: "torch.Tensor"): - return [tensor for i in range(self.mesh_device.get_num_devices())] +# def map(self, tensor: "torch.Tensor"): +# return [tensor for i in range(self.mesh_device.get_num_devices())] - def config(self): - return { - "strategy": "replicate", - "replication_factor": str(self.mesh_device.get_num_devices()), - } +# def config(self): +# return { +# "strategy": "replicate", +# "replication_factor": str(self.mesh_device.get_num_devices()), +# } -class ConcatMeshToTensor(MeshToTensor): - def __init__(self, mesh_device: MeshDevice, dim: int): - self.concat_dim = dim - self.mesh_device = mesh_device +# class ConcatMeshToTensor(MeshToTensor): +# def __init__(self, mesh_device: MeshDevice, dim: int): +# self.concat_dim = dim +# self.mesh_device = mesh_device - def compose(self, tensor: ttnn.Tensor) -> "torch.Tensor": - import torch +# def compose(self, tensor: ttnn.Tensor) -> "torch.Tensor": +# import torch - device_shards_converted_to_torch = [ - ttnn.to_torch(tt_input_tensor, mesh_composer=None) for tt_input_tensor in ttnn.get_device_tensors(tensor) - ] - return torch.cat(device_shards_converted_to_torch, dim=self.concat_dim) +# device_shards_converted_to_torch = [ +# ttnn.to_torch(tt_input_tensor, mesh_composer=None) for tt_input_tensor in ttnn.get_device_tensors(tensor) +# ] +# return torch.cat(device_shards_converted_to_torch, dim=self.concat_dim) @contextlib.contextmanager -def distribute(default: Union[TensorToMesh, MeshToTensor]): +def distribute(default: Union[ttnn.TensorToMesh, ttnn.MeshToTensor]): """ Context manager to temporarily modify the behavior of ttnn.from_torch and ttnn.to_torch to use the specified mesh_mapper or mesh_composer for tensor distribution and composition to/from MeshDevice. 
@@ -447,9 +465,9 @@ def distribute(default: Union[TensorToMesh, MeshToTensor]): _original_from_torch = ttnn.from_torch try: - if isinstance(default, TensorToMesh): + if isinstance(default, ttnn.TensorToMesh): ttnn.from_torch = functools.partial(_original_from_torch, mesh_mapper=default) - elif isinstance(default, MeshToTensor): + elif isinstance(default, ttnn.MeshToTensor): ttnn.to_torch = functools.partial(_original_to_torch, mesh_composer=default) else: raise ValueError("Argument must be an instance of either TensorToMesh or MeshToTensor.") From f4cb2496d675a66715685841ca3aa276bcb7a361 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Fri, 14 Feb 2025 19:32:13 +0000 Subject: [PATCH 05/76] fix mesh device conflict, add aggregate/distribute and config pybinds, fix keyword error --- .../distributed/test_distributed_tensor.py | 46 ++++++++++++------- .../ttnn/distributed/distributed_pybind.cpp | 39 ++++++++++++---- ttnn/ttnn/__init__.py | 6 +++ ttnn/ttnn/distributed/__init__.py | 1 - 4 files changed, 65 insertions(+), 27 deletions(-) diff --git a/tests/ttnn/distributed/test_distributed_tensor.py b/tests/ttnn/distributed/test_distributed_tensor.py index 7248bf0cf63..c766e7a63fc 100644 --- a/tests/ttnn/distributed/test_distributed_tensor.py +++ b/tests/ttnn/distributed/test_distributed_tensor.py @@ -34,15 +34,14 @@ def test_replicate_to_tensor_mesh(mesh_device, dtype): torch.manual_seed(1234) torch_tensor = torch.randn(1, 1, 32, 8192) - replicated = ttnn.from_torch( + to_repl = ttnn.from_torch( torch_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), ) - out_tensors = ttnn.get_device_tensors(mesh_device) + out_tensors = ttnn.ReplicateTensorToMesh(mesh_device).map(to_repl) test = ttnn.from_torch(torch_tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device) @@ -51,6 +50,7 @@ def test_replicate_to_tensor_mesh(mesh_device, dtype): assert out_pass +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) def test_shard_to_tensor_mesh(mesh_device, dtype): torch.manual_seed(1234) @@ -60,12 +60,11 @@ def test_shard_to_tensor_mesh(mesh_device, dtype): dtype=dtype, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=3), ) test = ttnn.from_torch(torch_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=mesh_device) - out_tensors = ttnn.get_device_tensors(tensor_shards) + out_tensors = ttnn.ShardTensorToMesh(mesh_device, dim=3).map(tensor_shards) out_tensor = ttnn.aggregate_as_tensor(out_tensors) @@ -74,23 +73,23 @@ def test_shard_to_tensor_mesh(mesh_device, dtype): assert out_pass +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) def test_concat_to_tensor(mesh_device, dtype): torch.manual_seed(1234) torch_tensor = torch.randn(1, 1, 8192, 32768) - sharded = ttnn.from_torch( + to_shard = ttnn.from_torch( torch_tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=3), ) test = ttnn.from_torch(torch_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=mesh_device) - out_tensor = ttnn.to_torch( - torch_tensor, dtype=ttnn.bfloat16, mesh_composer=ttnn.ConcatMeshToTensor(dim=3), device=mesh_device - ) + sharded_tensors = ttnn.ShardTensorToMesh(mesh_device, dim=3).map(to_shard) + + out_tensor = ttnn.to_torch(ttnn.ConcatMeshToTensor(dim=3).compose(), dtype=ttnn.bfloat16, device=mesh_device) 
out_pass, out_pcc = comp_pcc(out_tensor, test, pcc=0.99) logger.info(f"PCC value: {out_pcc}") @@ -104,6 +103,7 @@ def test_concat_to_tensor(mesh_device, dtype): "M, K, N", [pytest.param(32, 8192, 28 * 1024), pytest.param(32, 28 * 1024, 8192)], ) +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): torch.manual_seed(1234) @@ -124,16 +124,17 @@ def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): use_height_and_width_as_shard_shape=True, ) - tensor_shards = ttnn.from_torch( + to_shard = ttnn.from_torch( torch_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT, memory_config=sharded_mem_config if M == 32 else ttnn.DRAM_MEMORY_CONFIG, device=mesh_device, - mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim), ) - out_tensors = ttnn.get_device_tensors(tensor_shards) + out_tensors = ttnn.get_device_tensors( + ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim).map(to_shard) + ) for tensor in out_tensors: print(tensor) @@ -143,6 +144,14 @@ def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): # assert out_pass +@pytest.mark.parametrize( + "mesh_shape, mesh_device", [pytest.param((8, 4), (8, 4), id="8x4_grid")], indirect=["mesh_device"] +) +@pytest.mark.parametrize( + "M, K, N", + [pytest.param(32, 8192, 28 * 1024), pytest.param(32, 28 * 1024, 8192)], +) +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) def test_concat2d_to_tensor(M, K, N, dtype, mesh_shape, mesh_device): torch.manual_seed(1234) @@ -164,17 +173,22 @@ def test_concat2d_to_tensor(M, K, N, dtype, mesh_shape, mesh_device): use_height_and_width_as_shard_shape=True, ) - tensor_shards = ttnn.from_torch( + to_shard = ttnn.from_torch( torch_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT, memory_config=sharded_mem_config if M == 32 else ttnn.DRAM_MEMORY_CONFIG, device=mesh_device, - mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim), + ) + + sharded_tensors = ttnn.get_device_tensors( + ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim).map(to_shard) ) out = ttnn.to_torch( - tensor_shards, mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, dims=concat_dim, mesh_shape=mesh_shape) + mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, dims=concat_dim, mesh_shape=mesh_shape).compose( + sharded_tensors + ), ) out_pass, out_pcc = comp_pcc(out, torch_tensor, pcc=0.99) diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 42bf7a7b5e6..7848edde7a5 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -64,6 +64,13 @@ void py_module_types(py::module& module) { py::class_>( module, "ConcatMesh2dToTensor"); + py::class_(module, "ReplicateTensor"); + py::class_(module, "ShardTensor"); + py::class_(module, "ShardTensor2D"); + py::class_(module, "ShardMesh"); + py::class_(module, "AllGatherTensor"); + py::class_(module, "DistributedTensorConfig"); + py::class_>(module, "MeshDevice"); py::class_(module, "MeshSubDeviceManagerId"); py::class_(module, "MeshShape", "Shape of a mesh device."); @@ -415,13 +422,11 @@ void py_module(py::module& module) { py::init([](MeshDevice& mesh_device) -> std::unique_ptr { return std::make_unique(ReplicateTensorToMesh(mesh_device.num_devices())); }), - 
py::kw_only(), py::arg("mesh_device")) .def( py::init([](size_t num_devices) -> std::unique_ptr { return std::make_unique(ReplicateTensorToMesh(num_devices)); }), - py::kw_only(), py::arg("num_devices")) .def( "map", @@ -436,14 +441,12 @@ void py_module(py::module& module) { py::init([](MeshDevice& mesh_device, int dim) -> std::unique_ptr { return std::make_unique(ShardTensorToMesh(mesh_device, dim)); }), - py::kw_only(), py::arg("mesh_device"), py::arg("dim")) .def( py::init([](size_t num_devices, int dim) -> std::unique_ptr { return std::make_unique(ShardTensorToMesh(num_devices, dim)); }), - py::kw_only(), py::arg("num_devices"), py::arg("dim")) .def( @@ -462,7 +465,6 @@ void py_module(py::module& module) { const Shard2dConfig& config) -> std::unique_ptr { return std::make_unique(ShardTensor2dMesh(mesh_device, mesh_shape, config)); }), - py::kw_only(), py::arg("mesh_device"), py::arg("mesh_shape"), py::arg("config")) @@ -471,7 +473,6 @@ void py_module(py::module& module) { [](const MeshShape& mesh_shape, const Shard2dConfig& config) -> std::unique_ptr { return std::make_unique(ShardTensor2dMesh(mesh_shape, config)); }), - py::kw_only(), py::arg("mesh_shape"), py::arg("config")) .def( @@ -493,7 +494,6 @@ void py_module(py::module& module) { py::init([](int dim) -> std::unique_ptr { return std::make_unique(dim); }), - py::kw_only(), py::arg("dim")) .def( "compose", @@ -514,7 +514,6 @@ void py_module(py::module& module) { config.col_dim); return std::make_unique(mesh_device, config); }), - py::kw_only(), py::arg("mesh_device"), py::arg("config")) .def( @@ -574,6 +573,7 @@ void py_module(py::module& module) { Tensor: The shard of the tensor corresponding to the device. )doc"); module.def("get_device_tensors", &get_device_tensors, py::arg("tensor"), py::kw_only()); + // TODO: Add rdocs module.def( "replicate_tensor_to_mesh_mapper", [](MeshDevice& mesh_device) -> std::unique_ptr { @@ -608,8 +608,27 @@ void py_module(py::module& module) { }, py::arg("mesh_device"), py::arg("config")); - // TODO: overload this method to enable selection of a subset of shards with a config or something before passing to - // aggregate + module.def( + "distribute_tensor", + [](const Tensor& tensor, + const TensorToMesh& mapper, + std::optional> mesh_device) -> Tensor { + return distribute_tensor(tensor, mapper, mesh_device); + }, + py::arg("tensor"), + py::arg("mapper")); + module.def( + "aggregate_tensor", + [](const Tensor& tensor, const MeshToTensor& composer) -> Tensor { return aggregate_tensor(tensor, composer); }, + py::arg("tensor"), + py::arg("composer")); + module.def( + "aggregate_tensor", + [](const std::vector& tensors, const MeshToTensor& composer) -> Tensor { + return aggregate_tensor(aggregate_as_tensor(tensors, AllGatherTensor{}), composer); + }, + py::arg("tensor"), + py::arg("composer")); module.def( "aggregate_as_tensor", [](const std::vector& tensors) -> Tensor { return aggregate_as_tensor(tensors, AllGatherTensor{}); }, diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index 87e58785cc2..8498b5ec5ac 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -103,6 +103,12 @@ def manage_config(name, value): ShardTensor2dMesh, ConcatMeshToTensor, ConcatMesh2dToTensor, + ReplicateTensor, + ShardTensor, + ShardTensor2d, + ShardMesh, + AllGatherTensor, + DistributedTensorConfig, get_device_tensor, get_device_tensors, aggregate_as_tensor, diff --git a/ttnn/ttnn/distributed/__init__.py b/ttnn/ttnn/distributed/__init__.py index 1e566d85567..e41931f36d5 100644 --- 
a/ttnn/ttnn/distributed/__init__.py +++ b/ttnn/ttnn/distributed/__init__.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: Apache-2.0 from .distributed import ( - MeshDevice, DispatchCoreType, open_mesh_device, close_mesh_device, From 8fc6e5fd9d1cbf1fa68bde4736bd400c54ed2e87 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Fri, 14 Feb 2025 19:38:34 +0000 Subject: [PATCH 06/76] add aggregate/distribute imports to init --- ttnn/ttnn/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index 8498b5ec5ac..8b32b91bd40 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -112,6 +112,8 @@ def manage_config(name, value): get_device_tensor, get_device_tensors, aggregate_as_tensor, + aggregate_tensor, + distribute_tensor, get_t3k_physical_device_ids_ring, ) From f45b660888768ccce27ec33b4acfe433d86d4c15 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Fri, 14 Feb 2025 20:58:21 +0000 Subject: [PATCH 07/76] add configs to pybind --- .../ttnn/distributed/distributed_pybind.cpp | 45 +++++++++++++++++-- 1 file changed, 41 insertions(+), 4 deletions(-) diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 7848edde7a5..b0012d8eddb 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -550,10 +550,46 @@ void py_module(py::module& module) { tensor (Tensor): The tensor to get the shard from. device_id (int): The device id to get the shard for. + Returns: + Tensor: The shard of the tensor corresponding to the device_id. + )doc"); + + auto py_replicate_tensor_config = static_cast>(module.attr("ShardTensor")); + py_replicate_tensor_config.def(py::init<>()) + .def(py::init(), py::arg("replication_factor") = 1) + .def_readwrite("shard_dimension", &ShardTensor::shard_dimension) + .def("__eq__", [](const ReplicateTensor& a, const ReplicateTensor& b) { + return a.replication_factor == b.replication_factor; + }); + + auto py_shard_tensor_config = static_cast>(module.attr("ShardTensor")); + py_shard_tensor_config.def(py::init(), py::arg("shard_dimension")) + .def_readwrite("shard_dimension", &ShardTensor::shard_dimension) + .def("__eq__", [](const ShardTensor& a, const ShardTensor& b) { return a == b; }); + + auto py_shard_mesh = static_cast>(module.attr("ShardMesh")); + py_shard_mesh.def(py::init<>()).def_readwrite("y", &ShardMesh::y).def_readwrite("x", &ShardMesh::x); + + auto py_shard_tensor2d = static_cast>(module.attr("ShardTensor2D")); + py_shard_tensor2d.def(py::init(), py::arg("mesh")) + .def_readonly("shard_mesh", &ShardTensor2D::shard_mesh) + .def("__eq__", [](const ShardTensor2D& a, const ShardTensor2D& b) { return a == b; }); + + auto py_allgather_config = static_cast>(module.attr("AllGatherTensor")); + .def(py::init<>()).def("__eq__", [](const AllGatherTensor& a, const AllGatherTensor& b) { return a == b; }); - Returns: - Tensor: The shard of the tensor corresponding to the device_id. 
- )doc"); + module.def( + "get_distributed_tensor_config", + &get_distributed_tensor_config, + py::arg("metadata"), + R"doc( + Returns a distributed_tensor_config object given a valid metadata object of the type + + { + "item": "field", + "item": "field", + } + )doc"); module.def( "get_device_tensor", py::overload_cast(&ttnn::distributed::get_device_tensor), @@ -616,7 +652,8 @@ void py_module(py::module& module) { return distribute_tensor(tensor, mapper, mesh_device); }, py::arg("tensor"), - py::arg("mapper")); + py::arg("mapper"), + py::arg("mesh_device")); module.def( "aggregate_tensor", [](const Tensor& tensor, const MeshToTensor& composer) -> Tensor { return aggregate_tensor(tensor, composer); }, From a5759f53b5f48cb6bcf863f55b3bc3acd5cc7197 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Fri, 14 Feb 2025 21:04:47 +0000 Subject: [PATCH 08/76] change test cases to use distribute/aggregate --- .../distributed/test_distributed_tensor.py | 85 ++++++++++++------- 1 file changed, 53 insertions(+), 32 deletions(-) diff --git a/tests/ttnn/distributed/test_distributed_tensor.py b/tests/ttnn/distributed/test_distributed_tensor.py index c766e7a63fc..903cbf0a23c 100644 --- a/tests/ttnn/distributed/test_distributed_tensor.py +++ b/tests/ttnn/distributed/test_distributed_tensor.py @@ -11,6 +11,8 @@ from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc from ttnn import ( + distribute_tensor, + aggregate_tensor, ShardTensorToMesh, ShardTensor2dMesh, ReplicateTensorToMesh, @@ -41,11 +43,11 @@ def test_replicate_to_tensor_mesh(mesh_device, dtype): device=mesh_device, ) - out_tensors = ttnn.ReplicateTensorToMesh(mesh_device).map(to_repl) + mapper = ttnn.ReplicateTensorToMesh(mesh_device).map(to_repl) + replicated_tensors = ttnn.distribute_tensor(to_repl, mapper, mesh_device) + out_tensors = ttnn.get_device_tensors(replicated_tensors) - test = ttnn.from_torch(torch_tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device) - - out_pass, out_pcc = comp_pcc(out_tensors[0], test, pcc=0.99) + out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensors[0]), torch_tensor, pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass @@ -55,20 +57,18 @@ def test_shard_to_tensor_mesh(mesh_device, dtype): torch.manual_seed(1234) torch_tensor = torch.randn(1, 1, 8192, 32768) - tensor_shards = ttnn.from_torch( + to_shard = ttnn.from_torch( torch_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=mesh_device, ) - test = ttnn.from_torch(torch_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=mesh_device) + mapper = ttnn.ShardTensorToMesh(mesh_device, dim=3).map(to_shard) - out_tensors = ttnn.ShardTensorToMesh(mesh_device, dim=3).map(tensor_shards) + out_tensor = ttnn.distribute_tensor(to_shard, mapper, mesh_device) - out_tensor = ttnn.aggregate_as_tensor(out_tensors) - - out_pass, out_pcc = comp_pcc(out_tensor, test, pcc=0.99) + out_pass, out_pcc = comp_pcc(out_tensor, torch_tensor, pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass @@ -80,18 +80,44 @@ def test_concat_to_tensor(mesh_device, dtype): torch_tensor = torch.randn(1, 1, 8192, 32768) to_shard = ttnn.from_torch( torch_tensor, - dtype=ttnn.bfloat16, + dtype=dtype, layout=ttnn.TILE_LAYOUT, device=mesh_device, ) - test = ttnn.from_torch(torch_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=mesh_device) + mapper = ttnn.ShardTensorToMesh(mesh_device, dim=3).map(to_shard) - sharded_tensors = ttnn.ShardTensorToMesh(mesh_device, dim=3).map(to_shard) + composer = ttnn.ConcatMeshToTensor(dim=3) - 
out_tensor = ttnn.to_torch(ttnn.ConcatMeshToTensor(dim=3).compose(), dtype=ttnn.bfloat16, device=mesh_device) + out_tensor = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device), composer) - out_pass, out_pcc = comp_pcc(out_tensor, test, pcc=0.99) + out_pass, out_pcc = comp_pcc(out_tensor, torch_tensor, pcc=0.99) + logger.info(f"PCC value: {out_pcc}") + assert out_pass + + +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +def test_concat_slice_to_tensor(mesh_device, dtype): + torch.manual_seed(1234) + + torch_tensor = torch.randn(1, 1, 8192, 32768) + to_shard = ttnn.from_torch( + torch_tensor, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + device=mesh_device, + ) + + mapper = ttnn.ShardTensorToMesh(mesh_device, dim=3) + + composer = ttnn.ConcatMeshToTensor(dim=3) + + out_tensor = [] + out_tensor[0] = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device)[:-2], composer) + out_tensor[1] = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device)[:-1], composer) + out_tensor[2] = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device)[:0], composer) + + out_pass, out_pcc = comp_pcc(out_tensor, torch_tensor, pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass @@ -132,16 +158,15 @@ def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): device=mesh_device, ) - out_tensors = ttnn.get_device_tensors( - ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim).map(to_shard) - ) + mapper = ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim).map(to_shard) + + out_tensors = ttnn.get_device_tensors(ttnn.distribute_tensor(to_shard, mapper, mesh_device)) - for tensor in out_tensors: - print(tensor) + ttnn.aggregate_as_tensor(out_tensors, mesh_device) - # out_pass, out_pcc = comp_pcc(tensor_shards, out, pcc=0.99) - # logger.info(f"PCC value: {out_pcc}") - # assert out_pass + out_pass, out_pcc = comp_pcc(out_tensors, torch_tensor, pcc=0.99) + logger.info(f"PCC value: {out_pcc}") + assert out_pass @pytest.mark.parametrize( @@ -181,16 +206,12 @@ def test_concat2d_to_tensor(M, K, N, dtype, mesh_shape, mesh_device): device=mesh_device, ) - sharded_tensors = ttnn.get_device_tensors( - ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim).map(to_shard) - ) + mapper = ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) - out = ttnn.to_torch( - mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, dims=concat_dim, mesh_shape=mesh_shape).compose( - sharded_tensors - ), - ) + composer = ttnn.ConcatMesh2dToTensor(mesh_device, dims=concat_dim, mesh_shape=mesh_shape) + + out_tensor = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device), composer) - out_pass, out_pcc = comp_pcc(out, torch_tensor, pcc=0.99) + out_pass, out_pcc = comp_pcc(out_tensor, torch_tensor, pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass From cc14dd2fbda0c8320851a9554c2b3d7d0804606b Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Fri, 14 Feb 2025 22:58:12 +0000 Subject: [PATCH 09/76] fix test mappers, convert to cpu_tensor --- .../distributed/test_distributed_tensor.py | 8 ++--- .../ttnn/distributed/distributed_pybind.cpp | 30 ++++++++++++++----- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/tests/ttnn/distributed/test_distributed_tensor.py b/tests/ttnn/distributed/test_distributed_tensor.py index 903cbf0a23c..5abb2c0d690 100644 --- 
a/tests/ttnn/distributed/test_distributed_tensor.py +++ b/tests/ttnn/distributed/test_distributed_tensor.py @@ -43,7 +43,7 @@ def test_replicate_to_tensor_mesh(mesh_device, dtype): device=mesh_device, ) - mapper = ttnn.ReplicateTensorToMesh(mesh_device).map(to_repl) + mapper = ttnn.ReplicateTensorToMesh(mesh_device) replicated_tensors = ttnn.distribute_tensor(to_repl, mapper, mesh_device) out_tensors = ttnn.get_device_tensors(replicated_tensors) @@ -64,7 +64,7 @@ def test_shard_to_tensor_mesh(mesh_device, dtype): device=mesh_device, ) - mapper = ttnn.ShardTensorToMesh(mesh_device, dim=3).map(to_shard) + mapper = ttnn.ShardTensorToMesh(mesh_device, dim=3) out_tensor = ttnn.distribute_tensor(to_shard, mapper, mesh_device) @@ -85,7 +85,7 @@ def test_concat_to_tensor(mesh_device, dtype): device=mesh_device, ) - mapper = ttnn.ShardTensorToMesh(mesh_device, dim=3).map(to_shard) + mapper = ttnn.ShardTensorToMesh(mesh_device, dim=3) composer = ttnn.ConcatMeshToTensor(dim=3) @@ -158,7 +158,7 @@ def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): device=mesh_device, ) - mapper = ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim).map(to_shard) + mapper = ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) out_tensors = ttnn.get_device_tensors(ttnn.distribute_tensor(to_shard, mapper, mesh_device)) diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index b0012d8eddb..816550bf46c 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -13,11 +13,13 @@ #include #include "tt-metalium/mesh_coord.hpp" #include "distributed_tensor.hpp" +#include "tt-metalium/assert.hpp" #include "ttnn/distributed/api.hpp" <<<<<<< HEAD #include "ttnn/distributed/types.hpp" ======= #include "ttnn/distributed/distributed_tensor_config.hpp" +#include "ttnn/operations/core/core.hpp" #include "ttnn/tensor/tensor_utils.hpp" >>>>>>> one type error left #include "ttnn/tensor/tensor.hpp" @@ -53,6 +55,14 @@ struct ConcreteMeshToTensor : MeshToTensor { } }; +Tensor get_cpu_tensor(const Tensor& tensor) { + if (is_device_tensor(tensor)) { + Tensor cpu_tensor = tensor.cpu(); + TT_ASSERT(is_device_tensor(cpu_tensor)); + } + return tensor; +} + void py_module_types(py::module& module) { py::class_>(module, "MeshToTensor"); py::class_>(module, "TensorToMesh"); @@ -66,7 +76,7 @@ void py_module_types(py::module& module) { py::class_(module, "ReplicateTensor"); py::class_(module, "ShardTensor"); - py::class_(module, "ShardTensor2D"); + py::class_(module, "ShardTensor2d"); py::class_(module, "ShardMesh"); py::class_(module, "AllGatherTensor"); py::class_(module, "DistributedTensorConfig"); @@ -557,7 +567,7 @@ void py_module(py::module& module) { auto py_replicate_tensor_config = static_cast>(module.attr("ShardTensor")); py_replicate_tensor_config.def(py::init<>()) .def(py::init(), py::arg("replication_factor") = 1) - .def_readwrite("shard_dimension", &ShardTensor::shard_dimension) + .def_readwrite("shard_dimension", &ReplicateTensor::replication_factor) .def("__eq__", [](const ReplicateTensor& a, const ReplicateTensor& b) { return a.replication_factor == b.replication_factor; }); @@ -570,13 +580,15 @@ void py_module(py::module& module) { auto py_shard_mesh = static_cast>(module.attr("ShardMesh")); py_shard_mesh.def(py::init<>()).def_readwrite("y", &ShardMesh::y).def_readwrite("x", &ShardMesh::x); - auto py_shard_tensor2d = 
static_cast>(module.attr("ShardTensor2D")); + auto py_shard_tensor2d = static_cast>(module.attr("ShardTensor2d")); py_shard_tensor2d.def(py::init(), py::arg("mesh")) .def_readonly("shard_mesh", &ShardTensor2D::shard_mesh) .def("__eq__", [](const ShardTensor2D& a, const ShardTensor2D& b) { return a == b; }); - auto py_allgather_config = static_cast>(module.attr("AllGatherTensor")); - .def(py::init<>()).def("__eq__", [](const AllGatherTensor& a, const AllGatherTensor& b) { return a == b; }); + auto py_allgather_config = + static_cast>(module.attr("AllGatherTensor")) + .def(py::init<>()) + .def("__eq__", [](const AllGatherTensor& a, const AllGatherTensor& b) { return a == b; }); module.def( "get_distributed_tensor_config", @@ -649,20 +661,22 @@ void py_module(py::module& module) { [](const Tensor& tensor, const TensorToMesh& mapper, std::optional> mesh_device) -> Tensor { - return distribute_tensor(tensor, mapper, mesh_device); + return distribute_tensor(get_cpu_tensor(tensor), mapper, mesh_device); }, py::arg("tensor"), py::arg("mapper"), py::arg("mesh_device")); module.def( "aggregate_tensor", - [](const Tensor& tensor, const MeshToTensor& composer) -> Tensor { return aggregate_tensor(tensor, composer); }, + [](const Tensor& tensor, const MeshToTensor& composer) -> Tensor { + return aggregate_tensor(get_cpu_tensor(tensor), composer); + }, py::arg("tensor"), py::arg("composer")); module.def( "aggregate_tensor", [](const std::vector& tensors, const MeshToTensor& composer) -> Tensor { - return aggregate_tensor(aggregate_as_tensor(tensors, AllGatherTensor{}), composer); + return aggregate_tensor(get_cpu_tensor(aggregate_as_tensor(tensors, AllGatherTensor{})), composer); }, py::arg("tensor"), py::arg("composer")); From bd1b931a9b883f10aadf34392248f1a80d219800 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Tue, 18 Feb 2025 20:11:35 +0000 Subject: [PATCH 10/76] clean up imports, fix test cases and change them to use mapper/composer functions, fix storage error by using from_device instead of cpu --- .../distributed/test_distributed_tensor.py | 59 ++++++++----------- .../ttnn/distributed/distributed_pybind.cpp | 34 +++++++---- .../ttnn/distributed/distributed_pybind.hpp | 1 + ttnn/ttnn/__init__.py | 5 ++ 4 files changed, 53 insertions(+), 46 deletions(-) diff --git a/tests/ttnn/distributed/test_distributed_tensor.py b/tests/ttnn/distributed/test_distributed_tensor.py index 5abb2c0d690..3818a304470 100644 --- a/tests/ttnn/distributed/test_distributed_tensor.py +++ b/tests/ttnn/distributed/test_distributed_tensor.py @@ -3,24 +3,10 @@ # SPDX-License-Identifier: Apache-2.0 import torch -import typing import pytest import ttnn -import tempfile from loguru import logger from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc - -from ttnn import ( - distribute_tensor, - aggregate_tensor, - ShardTensorToMesh, - ShardTensor2dMesh, - ReplicateTensorToMesh, - ConcatMeshToTensor, - ConcatMesh2dToTensor, - MeshToTensor, - TensorToMesh, -) from models.utility_functions import nearest_32 @@ -43,7 +29,7 @@ def test_replicate_to_tensor_mesh(mesh_device, dtype): device=mesh_device, ) - mapper = ttnn.ReplicateTensorToMesh(mesh_device) + mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) replicated_tensors = ttnn.distribute_tensor(to_repl, mapper, mesh_device) out_tensors = ttnn.get_device_tensors(replicated_tensors) @@ -64,11 +50,13 @@ def test_shard_to_tensor_mesh(mesh_device, dtype): device=mesh_device, ) - mapper = ttnn.ShardTensorToMesh(mesh_device, dim=3) + mapper = 
ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3) + + shards = ttnn.get_device_tensors(ttnn.distribute_tensor(to_shard, mapper, mesh_device)) - out_tensor = ttnn.distribute_tensor(to_shard, mapper, mesh_device) + out_tensor = ttnn.aggregate_as_tensor(shards) - out_pass, out_pcc = comp_pcc(out_tensor, torch_tensor, pcc=0.99) + out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensor), torch_tensor, pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass @@ -85,13 +73,13 @@ def test_concat_to_tensor(mesh_device, dtype): device=mesh_device, ) - mapper = ttnn.ShardTensorToMesh(mesh_device, dim=3) + mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3) - composer = ttnn.ConcatMeshToTensor(dim=3) + composer = ttnn.concat_mesh_to_tensor_composer(dim=3) out_tensor = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device), composer) - out_pass, out_pcc = comp_pcc(out_tensor, torch_tensor, pcc=0.99) + out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensor), torch_tensor, pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass @@ -108,16 +96,17 @@ def test_concat_slice_to_tensor(mesh_device, dtype): device=mesh_device, ) - mapper = ttnn.ShardTensorToMesh(mesh_device, dim=3) + mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3) + + composer = ttnn.concat_mesh_to_tensor_composer(dim=3) + + sharded_tensor = ttnn.distribute_tensor(to_shard, mapper, mesh_device) - composer = ttnn.ConcatMeshToTensor(dim=3) + shards = ttnn.get_device_tensors(sharded_tensor) - out_tensor = [] - out_tensor[0] = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device)[:-2], composer) - out_tensor[1] = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device)[:-1], composer) - out_tensor[2] = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device)[:0], composer) + out_tensor = ttnn.aggregate_tensor(shards, composer) - out_pass, out_pcc = comp_pcc(out_tensor, torch_tensor, pcc=0.99) + out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensor), torch_tensor, pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass @@ -158,13 +147,13 @@ def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): device=mesh_device, ) - mapper = ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) + mapper = ttnn.shard_tensor_to_2d_mesh_mapper(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) - out_tensors = ttnn.get_device_tensors(ttnn.distribute_tensor(to_shard, mapper, mesh_device)) + shards = ttnn.get_device_tensors(ttnn.distribute_tensor(to_shard, mapper, mesh_device)) - ttnn.aggregate_as_tensor(out_tensors, mesh_device) + ttnn.aggregate_as_tensor(shards) - out_pass, out_pcc = comp_pcc(out_tensors, torch_tensor, pcc=0.99) + out_pass, out_pcc = comp_pcc(ttnn.to_torch(shards), torch_tensor, pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass @@ -206,12 +195,12 @@ def test_concat2d_to_tensor(M, K, N, dtype, mesh_shape, mesh_device): device=mesh_device, ) - mapper = ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) + mapper = ttnn.shard_tensor_to_2d_mesh_mapper(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) - composer = ttnn.ConcatMesh2dToTensor(mesh_device, dims=concat_dim, mesh_shape=mesh_shape) + composer = ttnn.concat_2d_mesh_to_tensor_composer(mesh_device, dims=concat_dim, mesh_shape=mesh_shape) out_tensor = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device), composer) - out_pass, out_pcc = comp_pcc(out_tensor, torch_tensor, 
pcc=0.99) + out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensor), torch_tensor, pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 816550bf46c..6a2be17bf9d 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -4,9 +4,12 @@ #include "ttnn/distributed/distributed_pybind.hpp" #include +<<<<<<< HEAD #include #include #include +======= +>>>>>>> clean up imports, fix test cases and change them to use mapper/composer functions, fix storage error by using from_device instead of cpu #include @@ -20,10 +23,17 @@ ======= #include "ttnn/distributed/distributed_tensor_config.hpp" #include "ttnn/operations/core/core.hpp" +<<<<<<< HEAD #include "ttnn/tensor/tensor_utils.hpp" >>>>>>> one type error left +======= +>>>>>>> clean up imports, fix test cases and change them to use mapper/composer functions, fix storage error by using from_device instead of cpu #include "ttnn/tensor/tensor.hpp" +<<<<<<< HEAD #include "ttnn/types.hpp" +======= +#include +>>>>>>> clean up imports, fix test cases and change them to use mapper/composer functions, fix storage error by using from_device instead of cpu // This is required for automatic conversions, as in the creation of mesh devices // https://github.com/tenstorrent/tt-metal/issues/18082 @@ -55,14 +65,6 @@ struct ConcreteMeshToTensor : MeshToTensor { } }; -Tensor get_cpu_tensor(const Tensor& tensor) { - if (is_device_tensor(tensor)) { - Tensor cpu_tensor = tensor.cpu(); - TT_ASSERT(is_device_tensor(cpu_tensor)); - } - return tensor; -} - void py_module_types(py::module& module) { py::class_>(module, "MeshToTensor"); py::class_>(module, "TensorToMesh"); @@ -611,6 +613,7 @@ void py_module(py::module& module) { R"doc( Get the tensor shard corresponding to the device. +<<<<<<< HEAD Args: tensor (Tensor): The tensor to get the shard from. @@ -620,6 +623,15 @@ void py_module(py::module& module) { Returns: Tensor: The shard of the tensor corresponding to the device. )doc"); +======= + Args: + tensor (Tensor): The tensor to get the shard from. + device (Device): The device to get the shard for. +aggregate_as + Returns: + Tensor: The shard of the tensor corresponding to the device. 
+ )doc"); +>>>>>>> clean up imports, fix test cases and change them to use mapper/composer functions, fix storage error by using from_device instead of cpu module.def("get_device_tensors", &get_device_tensors, py::arg("tensor"), py::kw_only()); // TODO: Add rdocs module.def( @@ -661,7 +673,7 @@ void py_module(py::module& module) { [](const Tensor& tensor, const TensorToMesh& mapper, std::optional> mesh_device) -> Tensor { - return distribute_tensor(get_cpu_tensor(tensor), mapper, mesh_device); + return distribute_tensor(from_device(tensor), mapper, mesh_device); }, py::arg("tensor"), py::arg("mapper"), @@ -669,14 +681,14 @@ void py_module(py::module& module) { module.def( "aggregate_tensor", [](const Tensor& tensor, const MeshToTensor& composer) -> Tensor { - return aggregate_tensor(get_cpu_tensor(tensor), composer); + return aggregate_tensor(from_device(tensor), composer); }, py::arg("tensor"), py::arg("composer")); module.def( "aggregate_tensor", [](const std::vector& tensors, const MeshToTensor& composer) -> Tensor { - return aggregate_tensor(get_cpu_tensor(aggregate_as_tensor(tensors, AllGatherTensor{})), composer); + return aggregate_tensor(from_device(aggregate_as_tensor(tensors, AllGatherTensor{})), composer); }, py::arg("tensor"), py::arg("composer")); diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.hpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.hpp index 93d26f3f2d6..25c384363bc 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.hpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.hpp @@ -5,6 +5,7 @@ #pragma once #include "pybind11/pybind_fwd.hpp" #include +#include "pybind11/stl.h" namespace py = pybind11; diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index 8b32b91bd40..a81899faaa7 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -112,6 +112,11 @@ def manage_config(name, value): get_device_tensor, get_device_tensors, aggregate_as_tensor, + replicate_tensor_to_mesh_mapper, + shard_tensor_to_mesh_mapper, + shard_tensor_to_2d_mesh_mapper, + concat_mesh_to_tensor_composer, + concat_2d_mesh_to_tensor_composer, aggregate_tensor, distribute_tensor, get_t3k_physical_device_ids_ring, From 3f26cb2a9c696e47fc3db942025d64165d310cbe Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Tue, 18 Feb 2025 20:15:29 +0000 Subject: [PATCH 11/76] remove python implementations --- ttnn/ttnn/distributed/distributed.py | 232 --------------------------- 1 file changed, 232 deletions(-) diff --git a/ttnn/ttnn/distributed/distributed.py b/ttnn/ttnn/distributed/distributed.py index d59377181df..b29089e72f6 100644 --- a/ttnn/ttnn/distributed/distributed.py +++ b/ttnn/ttnn/distributed/distributed.py @@ -209,238 +209,6 @@ def synchronize_devices( ttnn._ttnn.device.synchronize_device(devices.get_device(device), queue_id, sub_device_ids) -# class TensorToMesh: -# """ -# Defines the mapping of a torch.Tensor to a device mesh: e.g. Shard/Replicate. -# You can also "Bring your own TensorToMesh" based on your custom mapping. -# """ - -# def __init__(self, mesh_device): -# self.mesh_device = mesh_device - -# def map(self, tensor: "torch.Tensor"): -# raise NotImplementedError("Subclasses must implement this method") - -# def config(self): -# raise NotImplementedError("Subclasses must implement this method") - - -# class MeshToTensor: -# """ -# Defines the inverse operation of TensorToMesh. Given a set of per-device -# ttnn.Tensor objects (aggregated into a single ttnn.Tensor), this class defines -# the mapping back to one or many torch.Tensor objects. 
- -# You can also "Bring your own MeshToTensor" based on your custom mapping. -# """ - -# def compose(self, tensor: ttnn.Tensor): -# raise NotImplementedError("Subclasses must implement this method") - - -# class ShardTensorToMesh(TensorToMesh): -# def __init__(self, mesh_device, dim): -# super().__init__(mesh_device) -# self.shard_dim = dim - -# def map(self, tensor: "torch.Tensor") -> Dict[int, ttnn.Tensor]: -# import torch - -# sliced_tensors = torch.chunk(tensor, self.mesh_device.get_num_devices(), dim=self.shard_dim) -# return list(sliced_tensors) - -# def config(self): -# return { -# "strategy": "shard", -# "shard_dim": f"{self.shard_dim}", -# } - - -# class ShardTensor2dMesh(TensorToMesh): -# """ -# Shard a tensor across a 2D mesh of devices. - -# This class implements a strategy for distributing a tensor across a 2D grid of devices, -# allowing for efficient parallel processing in distributed computing environments. -# """ - -# def __init__(self, mesh_device: MeshDevice, mesh_shape: Tuple[int, int], dims: Tuple[Optional[int], Optional[int]]): -# """ -# Initialize the ShardTensor2dMesh. - -# Args: -# mesh_device: The target device mesh for distributing the tensor. -# mesh_shape: The shape of the 2D mesh as (rows, cols). -# dims: The dimensions to shard along, specified as (row_dim, col_dim). - -# The `dims` tuple determines how the tensor is sharded across the 2D mesh: -# - row_dim: The dimension to shard across mesh rows (or None for replication). -# - col_dim: The dimension to shard across mesh columns (or None for replication). - -# Examples: -# 1. dims=(2, 3) for a tensor of shape (A, B, C, D): -# - Shard along dimension 2 (C) across mesh rows -# - Shard along dimension 3 (D) across mesh columns - -# 2. dims=(None, 3): -# - Replicate across mesh rows -# - Shard along dimension 3 (D) across mesh columns - -# 3. dims=(None, None): -# - Fully replicate the tensor across all devices -# """ -# super().__init__(mesh_device) -# self.mesh_shape: Tuple[int, int] = mesh_shape -# self.dims: Tuple[Optional[int], Optional[int]] = dims - -# mesh_device_rows, mesh_device_cols = self.mesh_device.shape -# if mesh_shape[0] > mesh_device_rows or mesh_shape[1] > mesh_device_cols: -# raise ValueError("ShardTensor2dMesh: Device mesh shape does not match the provided mesh shape.") - -# def map(self, tensor: "torch.Tensor") -> List["torch.Tensor"]: -# """ -# Map the input tensor to a list of sharded tensors. - -# Args: -# tensor: The input tensor to be sharded. - -# Returns: -# A list of sharded tensors, one for each device in the mesh. - -# Raises: -# ValueError: If the number of sharding dimensions is not 2. -# """ -# import torch - -# if len(self.dims) != 2: -# raise ValueError("ShardTensor2dMesh only supports 2D shard dimensions") - -# rows, cols = self.mesh_shape -# row_dim, col_dim = self.dims - -# # Shard along rows -# row_tensors = ( -# [tensor.clone() for _ in range(rows)] if row_dim is None else torch.chunk(tensor, rows, dim=row_dim) -# ) - -# # Shard along columns -# if col_dim is None: -# return [t.clone() for t in row_tensors for _ in range(cols)] -# tensor_shards = [tt for t in row_tensors for tt in torch.chunk(t, cols, dim=col_dim)] - -# if len(tensor_shards) != rows * cols: -# raise ValueError( -# f"ShardTensor2dMesh: Sharding failed. Number of shards should match the product of the mesh dimensions. Got {len(tensor_shards)} shards but expected {rows * cols} ({rows} rows * {cols} cols)." 
-# ) - -# return tensor_shards - -# def config(self) -> Dict[str, str]: -# """ -# Provide the configuration of the sharding strategy. - -# Returns: -# A dictionary containing the sharding strategy and dimensions. -# """ -# return { -# "strategy": "shard_2d", -# "mesh_shape_y": str(self.mesh_shape[0]), -# "mesh_shape_x": str(self.mesh_shape[1]), -# } - - -# class ConcatMesh2dToTensor(MeshToTensor): -# """ -# Concatenate tensors from a 2D mesh back into a single tensor. - -# This class implements the inverse operation of ShardTensor2dMesh, combining -# sharded tensors from a 2D device mesh back into a single tensor. -# """ - -# def __init__(self, mesh_device: MeshDevice, mesh_shape: Tuple[int, int], dims: Tuple[int, int]): -# """ -# Initialize the ConcatMesh2dToTensor. - -# Args: -# mesh_device: The source device mesh containing the sharded tensors. -# mesh_shape: The shape of the 2D mesh as (rows, cols). -# dims: A tuple of two integers specifying the dimensions along which to concatenate the tensors. -# The first element (row_dim) indicates the dimension for concatenating tensors from different rows. -# The second element (col_dim) indicates the dimension for concatenating tensors from different columns. -# Both dimensions must be specified and different from each other. -# These dimensions correspond to the tensor dimensions, not the mesh dimensions. -# For example, if the original tensor was 4D with shape (batch, channel, height, width), -# and it was sharded across height and width, dims might be (-2, -1) or (2, 3). - -# Raises: -# ValueError: If either dimension in 'dims' is None or if both dimensions are the same. -# """ -# self.mesh_device = mesh_device -# self.mesh_shape = mesh_shape -# self.dims = dims -# if self.dims[0] == self.dims[1]: -# raise ValueError("Both dimensions in 'dims' must be different") - -# def compose(self, tensor: ttnn.Tensor) -> "torch.Tensor": -# """ -# Compose the sharded tensors back into a single tensor. - -# Args: -# tensor: A ttnn.Tensor object containing the sharded tensors distributed across multiple devices. - -# Returns: -# A single torch.Tensor that combines all the sharded tensors from all devices. - -# This method first concatenates the shards along the column dimension within each row, -# then concatenates the resulting tensors along the row dimension to form the final tensor. 
-# """ -# import torch - -# device_shards = [ -# ttnn.to_torch(tt_input_tensor, mesh_composer=None) for tt_input_tensor in ttnn.get_device_tensors(tensor) -# ] - -# rows, cols = self.mesh_shape -# row_dim, col_dim = self.dims - -# # Reshape the list of shards into a 2D list representing the device mesh -# mesh_shape = [device_shards[i : i + cols] for i in range(0, len(device_shards), cols)] - -# # Concatenate along columns first (within each row) -# row_concatenated = [torch.cat(row, dim=col_dim) for row in mesh_shape] - -# # Then concatenate the resulting tensors along rows -# return torch.cat(row_concatenated, dim=row_dim) - - -# class ReplicateTensorToMesh(TensorToMesh): -# def __init__(self, mesh_device: MeshDevice): -# super().__init__(mesh_device) - -# def map(self, tensor: "torch.Tensor"): -# return [tensor for i in range(self.mesh_device.get_num_devices())] - -# def config(self): -# return { -# "strategy": "replicate", -# "replication_factor": str(self.mesh_device.get_num_devices()), -# } - - -# class ConcatMeshToTensor(MeshToTensor): -# def __init__(self, mesh_device: MeshDevice, dim: int): -# self.concat_dim = dim -# self.mesh_device = mesh_device - -# def compose(self, tensor: ttnn.Tensor) -> "torch.Tensor": -# import torch - -# device_shards_converted_to_torch = [ -# ttnn.to_torch(tt_input_tensor, mesh_composer=None) for tt_input_tensor in ttnn.get_device_tensors(tensor) -# ] -# return torch.cat(device_shards_converted_to_torch, dim=self.concat_dim) - - @contextlib.contextmanager def distribute(default: Union[ttnn.TensorToMesh, ttnn.MeshToTensor]): """ From cb232668237c0fed5f1f9a2b80ac9cc8bc3f757c Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Tue, 18 Feb 2025 20:24:57 +0000 Subject: [PATCH 12/76] fix rebase --- .../ttnn/distributed/distributed_pybind.cpp | 24 ++++--------------- 1 file changed, 4 insertions(+), 20 deletions(-) diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 6a2be17bf9d..a3567ed21cc 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -4,15 +4,14 @@ #include "ttnn/distributed/distributed_pybind.hpp" #include -<<<<<<< HEAD #include -#include -#include -======= ->>>>>>> clean up imports, fix test cases and change them to use mapper/composer functions, fix storage error by using from_device instead of cpu +<<<<<<< HEAD #include +======= +#include +>>>>>>> fix rebase #include #include "tt-metalium/mesh_coord.hpp" #include "distributed_tensor.hpp" @@ -29,11 +28,7 @@ ======= >>>>>>> clean up imports, fix test cases and change them to use mapper/composer functions, fix storage error by using from_device instead of cpu #include "ttnn/tensor/tensor.hpp" -<<<<<<< HEAD -#include "ttnn/types.hpp" -======= #include ->>>>>>> clean up imports, fix test cases and change them to use mapper/composer functions, fix storage error by using from_device instead of cpu // This is required for automatic conversions, as in the creation of mesh devices // https://github.com/tenstorrent/tt-metal/issues/18082 @@ -613,8 +608,6 @@ void py_module(py::module& module) { R"doc( Get the tensor shard corresponding to the device. -<<<<<<< HEAD - Args: tensor (Tensor): The tensor to get the shard from. device (Device): The device to get the shard for. @@ -623,15 +616,6 @@ void py_module(py::module& module) { Returns: Tensor: The shard of the tensor corresponding to the device. )doc"); -======= - Args: - tensor (Tensor): The tensor to get the shard from. 
- device (Device): The device to get the shard for. -aggregate_as - Returns: - Tensor: The shard of the tensor corresponding to the device. - )doc"); ->>>>>>> clean up imports, fix test cases and change them to use mapper/composer functions, fix storage error by using from_device instead of cpu module.def("get_device_tensors", &get_device_tensors, py::arg("tensor"), py::kw_only()); // TODO: Add rdocs module.def( From b05982655c1c67c67454a4985170b59c20f3615b Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Tue, 18 Feb 2025 20:33:36 +0000 Subject: [PATCH 13/76] clean up deprecated imports --- ttnn/ttnn/distributed/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ttnn/ttnn/distributed/__init__.py b/ttnn/ttnn/distributed/__init__.py index e41931f36d5..c776d4d91f6 100644 --- a/ttnn/ttnn/distributed/__init__.py +++ b/ttnn/ttnn/distributed/__init__.py @@ -17,6 +17,7 @@ ReplicateTensorToMesh, MeshToTensor, ConcatMeshToTensor, + synchronize_devices, visualize_mesh_device, distribute, ) From 24dead90248d6337c5e503a6e7d1a4c528a8d0e4 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Wed, 19 Feb 2025 00:35:00 +0000 Subject: [PATCH 14/76] add shard2dconfig, concat2dconfig methods and map/compose constructors --- .../ttnn/distributed/distributed_pybind.cpp | 75 ++++++++++++++++--- .../ttnn/distributed/distributed_tensor.cpp | 8 ++ .../ttnn/distributed/distributed_tensor.hpp | 12 ++- ttnn/ttnn/__init__.py | 3 + 4 files changed, 85 insertions(+), 13 deletions(-) diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index a3567ed21cc..2d5e80bdcbd 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -78,6 +78,9 @@ void py_module_types(py::module& module) { py::class_(module, "AllGatherTensor"); py::class_(module, "DistributedTensorConfig"); + py::class_(module, "Shard2dConfig"); + py::class_(module, "Concat2dConfig"); + py::class_>(module, "MeshDevice"); py::class_(module, "MeshSubDeviceManagerId"); py::class_(module, "MeshShape", "Shape of a mesh device."); @@ -419,11 +422,9 @@ void py_module(py::module& module) { .def(py::init([]() -> std::unique_ptr { return std::make_unique(); })) .def("map", &TensorToMesh::map) .def("config", &TensorToMesh::config); - auto py_replicate_tensor_to_mesh = static_cast>>( module.attr("ReplicateTensorToMesh")); - py_replicate_tensor_to_mesh .def( py::init([](MeshDevice& mesh_device) -> std::unique_ptr { @@ -440,7 +441,6 @@ void py_module(py::module& module) { [](const ReplicateTensorToMesh& self, const Tensor& tensor) { return self.map(tensor); }, py::arg("tensor")) .def("config", &ReplicateTensorToMesh::config); - auto py_shard_tensor_to_mesh = static_cast>>( module.attr("ShardTensorToMesh")); py_shard_tensor_to_mesh @@ -461,7 +461,6 @@ void py_module(py::module& module) { [](const ShardTensorToMesh& self, const Tensor& tensor) { return self.map(tensor); }, py::arg("tensor")) .def("config", &ShardTensorToMesh::config); - auto py_shard_tensor_to_2d_mesh = static_cast>>( module.attr("ShardTensor2dMesh")); py_shard_tensor_to_2d_mesh @@ -472,6 +471,18 @@ void py_module(py::module& module) { const Shard2dConfig& config) -> std::unique_ptr { return std::make_unique(ShardTensor2dMesh(mesh_device, mesh_shape, config)); }), + py::init( + [](MeshDevice& mesh_device, + const MeshShape& mesh_shape, + const std::tuple& config) -> std::unique_ptr { + return std::make_unique(ShardTensor2dMesh( + mesh_device, + mesh_shape, + Shard2dConfig{ + .row_dim 
= std::get<0>(config), + .col_dim = std::get<1>(config), + })); + }), py::arg("mesh_device"), py::arg("mesh_shape"), py::arg("config")) @@ -487,13 +498,11 @@ void py_module(py::module& module) { [](const ShardTensor2dMesh& self, const Tensor& tensor) { return self.map(tensor); }, py::arg("tensor")) .def("config", &ShardTensor2dMesh::config); - auto py_mesh_to_tensor = static_cast>>( module.attr("MeshToTensor")); py_mesh_to_tensor .def(py::init([]() -> std::unique_ptr { return std::make_unique(); })) .def("compose", &MeshToTensor::compose); - auto py_concat_mesh_to_tensor = static_cast>>( module.attr("ConcatMeshToTensor")); py_concat_mesh_to_tensor @@ -506,7 +515,6 @@ void py_module(py::module& module) { "compose", [](const ConcatMeshToTensor& self, const std::vector& tensors) { return self.compose(tensors); }, py::arg("tensors")); - auto py_concat_2d_mesh_to_tensor = static_cast>>( module.attr("ConcatMesh2dToTensor")); @@ -521,6 +529,24 @@ void py_module(py::module& module) { config.col_dim); return std::make_unique(mesh_device, config); }), + py::init( + [](MeshDevice& mesh_device, + const std::tuple config) -> std::unique_ptr { + int row_dim = std::get<0>(config); + int col_dim = std::get<1>(config); + TT_FATAL( + row_dim != col_dim, + "Dimensions in 'dims' must be different; got row_dim: {}, col_dim: {}", + row_dim, + col_dim); + return std::make_unique( + mesh_device, + Concat2dConfig{ + .row_dim = row_dim, + .col_dim = col_dim, + }); + }), + py::arg("mesh_device"), py::arg("config")) .def( @@ -541,7 +567,6 @@ void py_module(py::module& module) { py::arg("offset"), py::arg("physical_device_ids"), py::arg("dispatch_core_config")); - module.def("close_mesh_device", &close_mesh_device, py::arg("mesh_device"), py::kw_only()); module.def( "get_device_tensor", @@ -573,20 +598,26 @@ void py_module(py::module& module) { py_shard_tensor_config.def(py::init(), py::arg("shard_dimension")) .def_readwrite("shard_dimension", &ShardTensor::shard_dimension) .def("__eq__", [](const ShardTensor& a, const ShardTensor& b) { return a == b; }); - auto py_shard_mesh = static_cast>(module.attr("ShardMesh")); py_shard_mesh.def(py::init<>()).def_readwrite("y", &ShardMesh::y).def_readwrite("x", &ShardMesh::x); - auto py_shard_tensor2d = static_cast>(module.attr("ShardTensor2d")); py_shard_tensor2d.def(py::init(), py::arg("mesh")) .def_readonly("shard_mesh", &ShardTensor2D::shard_mesh) .def("__eq__", [](const ShardTensor2D& a, const ShardTensor2D& b) { return a == b; }); - auto py_allgather_config = static_cast>(module.attr("AllGatherTensor")) .def(py::init<>()) .def("__eq__", [](const AllGatherTensor& a, const AllGatherTensor& b) { return a == b; }); + auto py_shard2d_config = static_cast>(module.attr("Shard2dConfig")); + py_shard2d_config.def(py::init(), py::arg("row_dim"), py::arg("col_dim")) + .def_readwrite("row_dim", &Shard2dConfig::row_dim) + .def_readwrite("col_dim", &Shard2dConfig::col_dim); + auto py_concat2d_config = static_cast>(module.attr("Concat2dConfig")); + py_concat2d_config.def(py::init(), py::arg("row_dim"), py::arg("col_dim")) + .def_readwrite("row_dim", &Concat2dConfig::row_dim) + .def_readwrite("col_dim", &Concat2dConfig::col_dim); + module.def( "get_distributed_tensor_config", &get_distributed_tensor_config, @@ -599,6 +630,28 @@ void py_module(py::module& module) { "item": "field", } )doc"); + module.def( + "get_shard2d_config", + &get_shard2d_config, + py::arg("metadata"), + R"doc( + Returns a Shard2dConfig object given a valid metadata object of the type + { + "row_dim": "field", + 
"col_dim": "field", + } + )doc"); + module.def( + "get_concat2d_config", + &get_concat2d_config, + py::arg("metadata"), + R"doc( + Returns a Concat2dConfig object given a valid metadata object of the type + { + "row_dim": "field", + "col_dim": "field", + } + )doc"); module.def( "get_device_tensor", py::overload_cast(&ttnn::distributed::get_device_tensor), diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp index 897231fc8e6..f87e98baf4b 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp @@ -206,4 +206,12 @@ Tensor aggregate_tensor(const Tensor& tensor, const MeshToTensor& composer) { : composer.compose({tensor}); } +static Shard2dConfig get_shard2d_config(const std::unordered_map& metadata) { + return Shard2dConfig(std::stoi(metadata.at("row_dim")), std::stoi(metadata.at("col_dim"))); +} + +static Concat2dConfig get_concat2d_config(const std::unordered_map& metadata) { + return Concat2dConfig(std::stoi(metadata.at("row_dim")), std::stoi(metadata.at("col_dim"))); +} + } // namespace ttnn::distributed diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp index 7d45355c638..22c4803b437 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp @@ -94,7 +94,11 @@ class ShardTensor2dMesh : public TensorToMesh { } ShardTensor2dMesh(const MeshShape& mesh_shape, const Shard2dConfig& config) : - mesh_shape_(mesh_shape), config_(config) {} + mesh_shape_(mesh_shape), config_(config) { + TT_FATAL( + config.row_dim.has_value() || config.col_dim.has_value(), + "Sharding a tensor to 2D mesh requires at least one dimension to shard"); + } std::vector map(const Tensor& tensor) const override { const auto [rows, cols] = mesh_shape_; @@ -140,7 +144,7 @@ class ShardTensor2dMesh : public TensorToMesh { } tt::tt_metal::DistributedTensorConfig config() const override { - return DistributedTensorConfig{ShardTensor2D{ShardMesh{mesh_shape_.num_rows, mesh_shape_.num_cols}}}; + return DistributedTensorConfig{ShardTensor2D{ShardMesh{.y = mesh_shape_.num_rows, .x = mesh_shape_.num_cols}}}; } private: @@ -213,4 +217,8 @@ Tensor distribute_tensor( // Aggregates a multi-device tensor into a host tensor according to the `composer`. 
Tensor aggregate_tensor(const Tensor& tensor, const MeshToTensor& composer); +Shard2dConfig get_shard2d_config(const std::unordered_map& metadata); + +Concat2dConfig get_concat2d_config(const std::unordered_map& metadata); + } // namespace ttnn::distributed diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index a81899faaa7..312eeb61551 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -111,6 +111,9 @@ def manage_config(name, value): DistributedTensorConfig, get_device_tensor, get_device_tensors, + get_shard2d_config, + get_concat2d_config, + get_distributed_tensor_config, aggregate_as_tensor, replicate_tensor_to_mesh_mapper, shard_tensor_to_mesh_mapper, From 381de5d30ad63a2b7a4070f53b3f49e0f653ee52 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Wed, 19 Feb 2025 22:00:52 +0000 Subject: [PATCH 15/76] Replace none types, expose configs, fix tuple errors --- .../distributed/test_distributed_tensor.py | 4 +- .../ttnn/distributed/distributed_pybind.cpp | 160 ++++++++++++++++-- .../ttnn/distributed/distributed_tensor.cpp | 4 +- 3 files changed, 147 insertions(+), 21 deletions(-) diff --git a/tests/ttnn/distributed/test_distributed_tensor.py b/tests/ttnn/distributed/test_distributed_tensor.py index 3818a304470..b5456cfa0d2 100644 --- a/tests/ttnn/distributed/test_distributed_tensor.py +++ b/tests/ttnn/distributed/test_distributed_tensor.py @@ -126,7 +126,7 @@ def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): core_grid = ttnn.CoreGrid(y=1, x=8) # If K < N it's FF1-like test case, else FF2-like test case - shard_dim = (None, 3) if K < N else (3, None) # None means to replicate along this dim + shard_dim = (0, 3) if K < N else (3, 0) # None means to replicate along this dim K = K // mesh_shape[1] if K < N else K // mesh_shape[0] N = N // mesh_shape[0] if K < N else N // mesh_shape[1] @@ -173,7 +173,7 @@ def test_concat2d_to_tensor(M, K, N, dtype, mesh_shape, mesh_device): core_grid = ttnn.CoreGrid(y=1, x=8) # If K < N it's FF1-like test case, else FF2-like test case - shard_dim = (None, 3) if K < N else (3, None) # None means to replicate along this dim + shard_dim = (0, 3) if K < N else (3, 0) # None means to replicate along this dim concat_dim = (3, 1) if K < N else (1, 3) # dim 1 for reduce, dim 3 for concatenating fractures K = K // mesh_shape[1] if K < N else K // mesh_shape[0] diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 2d5e80bdcbd..25805cb0a2b 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -21,6 +21,7 @@ #include "ttnn/distributed/types.hpp" ======= #include "ttnn/distributed/distributed_tensor_config.hpp" +#include "ttnn/distributed/types.hpp" #include "ttnn/operations/core/core.hpp" <<<<<<< HEAD #include "ttnn/tensor/tensor_utils.hpp" @@ -97,6 +98,7 @@ void py_module(py::module& module) { py::arg("num_rows"), py::arg("num_cols")) .def( +<<<<<<< HEAD py::init([](size_t x, size_t y, size_t z) { return MeshShape(x, y, z); }), "Constructor with the specified 3D shape.", py::arg("x"), @@ -106,6 +108,13 @@ void py_module(py::module& module) { py::init([](const std::vector& shape) { return MeshShape(shape); }), "Constructor with the specified ND shape.", py::arg("shape")) +======= + py::init([](const std::tuple& dims) { return MeshShape(std::get<0>(dims), std::get<1>(dims)); }), + "Constructor with specified number of rows and columns as a tuple (rows, columns).", + py::arg("dims")) + 
.def_readwrite("num_rows", &MeshShape::num_rows, "Number of rows in the mesh.") + .def_readwrite("num_cols", &MeshShape::num_cols, "Number of columns in the mesh.") +>>>>>>> Replace none types, expose configs, fix tuple errors .def( "__repr__", [](const MeshShape& ms) { @@ -469,23 +478,24 @@ void py_module(py::module& module) { [](MeshDevice& mesh_device, const MeshShape& mesh_shape, const Shard2dConfig& config) -> std::unique_ptr { - return std::make_unique(ShardTensor2dMesh(mesh_device, mesh_shape, config)); + return std::make_unique(mesh_device, mesh_shape, config); }), + py::arg("mesh_device"), + py::arg("mesh_shape"), + py::arg("config")) + .def( py::init( [](MeshDevice& mesh_device, - const MeshShape& mesh_shape, - const std::tuple& config) -> std::unique_ptr { - return std::make_unique(ShardTensor2dMesh( + const std::tuple dims, + const MeshShape& mesh_shape) -> std::unique_ptr { + return std::make_unique( mesh_device, mesh_shape, - Shard2dConfig{ - .row_dim = std::get<0>(config), - .col_dim = std::get<1>(config), - })); + Shard2dConfig{.row_dim = std::get<0>(dims), .col_dim = std::get<1>(dims)}); }), py::arg("mesh_device"), - py::arg("mesh_shape"), - py::arg("config")) + py::arg("dims"), + py::arg("mesh_shape")) .def( py::init( [](const MeshShape& mesh_shape, const Shard2dConfig& config) -> std::unique_ptr { @@ -529,11 +539,13 @@ void py_module(py::module& module) { config.col_dim); return std::make_unique(mesh_device, config); }), + py::arg("mesh_device"), + py::arg("config")) + .def( py::init( - [](MeshDevice& mesh_device, - const std::tuple config) -> std::unique_ptr { - int row_dim = std::get<0>(config); - int col_dim = std::get<1>(config); + [](MeshDevice& mesh_device, const std::tuple dims) -> std::unique_ptr { + int row_dim = std::get<0>(dims); + int col_dim = std::get<1>(dims); TT_FATAL( row_dim != col_dim, "Dimensions in 'dims' must be different; got row_dim: {}, col_dim: {}", @@ -546,12 +558,59 @@ void py_module(py::module& module) { .col_dim = col_dim, }); }), + py::arg("mesh_device"), + py::arg("dims")) + .def( + py::init( + [](MeshDevice& mesh_device, + const Concat2dConfig& config, + MeshShape& mesh_shape) -> std::unique_ptr { + TT_FATAL( + config.row_dim != config.col_dim, + "Dimensions in 'dims' must be different; got row_dim: {}, col_dim: {}", + config.row_dim, + config.col_dim); + TT_FATAL( + mesh_shape.num_rows <= mesh_device.shape().num_rows && // + mesh_shape.num_cols <= mesh_device.shape().num_cols, + "Device mesh shape does not match the provided mesh shape."); + return std::make_unique(mesh_device, config); + }), py::arg("mesh_device"), - py::arg("config")) + py::arg("config"), + py::arg("mesh_shape")) + .def( + py::init( + [](MeshDevice& mesh_device, + const std::tuple dims, + MeshShape& mesh_shape) -> std::unique_ptr { + int row_dim = std::get<0>(dims); + int col_dim = std::get<1>(dims); + + TT_FATAL( + row_dim != col_dim, + "Dimensions in 'dims' must be different; got row_dim: {}, col_dim: {}", + row_dim, + col_dim); + TT_FATAL( + mesh_shape.num_rows <= mesh_device.shape().num_rows && // + mesh_shape.num_cols <= mesh_device.shape().num_cols, + "Device mesh shape does not match the provided mesh shape."); + + return std::make_unique( + mesh_device, + Concat2dConfig{ + .row_dim = row_dim, + .col_dim = col_dim, + }); + }), + py::arg("mesh_device"), + py::arg("dims"), + py::arg("mesh_shape")) .def( "compose", - [](ConcatMesh2dToTensor self, const std::vector& tensors) -> Tensor { + [](const ConcatMesh2dToTensor& self, const std::vector& tensors) -> Tensor { 
return self.compose(tensors); }, py::arg("tensors")); @@ -577,7 +636,6 @@ void py_module(py::module& module) { R"doc( Get the tensor shard corresponding to the device_id. - Args: tensor (Tensor): The tensor to get the shard from. device_id (int): The device id to get the shard for. @@ -694,6 +752,30 @@ void py_module(py::module& module) { py::arg("mesh_device"), py::arg("mesh_shape"), py::arg("config")); + module.def( + "shard_tensor_to_2d_mesh_mapper", + [](MeshDevice& mesh_device, + const std::tuple mesh_shape, + const std::tuple dims) -> std::unique_ptr { + return shard_tensor_to_2d_mesh_mapper( + mesh_device, + MeshShape(std::get<0>(mesh_shape), std::get<1>(mesh_shape)), + Shard2dConfig{.row_dim = std::get<0>(dims), .col_dim = std::get<1>(dims)}); + }, + py::arg("mesh_device"), + py::arg("mesh_shape"), + py::arg("dims"), + R"doc( + Create a ShardTensor2dMesh mapper with the given mesh device, mesh shape, and dimensions. + + Args: + mesh_device (MeshDevice): The mesh device to create the mapper for. + mesh_shape (MeshShape): The shape of the 2D mesh as (num_rows, num_cols). + dims (Tuple[int, int]): The dimensions to create the mapper for in (row, column) format. + + Returns: + TensorToMesh: The created ShardTensor2dMesh mapper. + )doc"); module.def( "concat_mesh_to_tensor_composer", [](int dim) -> std::unique_ptr { return concat_mesh_to_tensor_composer(dim); }, @@ -705,6 +787,50 @@ void py_module(py::module& module) { }, py::arg("mesh_device"), py::arg("config")); + module.def( + "concat_2d_mesh_to_tensor_composer", + [](MeshDevice& mesh_device, const std::tuple dims) -> std::unique_ptr { + return concat_2d_mesh_to_tensor_composer( + mesh_device, Concat2dConfig{.row_dim = std::get<0>(dims), .col_dim = std::get<1>(dims)}); + }, + py::arg("mesh_device"), + py::arg("dims"), + R"doc( + Create a ConcatMesh2dToTensor composer with the given mesh device and dimensions. + + Args: + mesh_device (MeshDevice): The mesh device to create the composer for. + dims (Tuple[int, int]): The dimensions to create the composer for in (row, column) format. + + Returns: + TensorToMesh: The created ConcatMesh2dToTensor composer. + )doc"); + module.def( + "concat_2d_mesh_to_tensor_composer", + [](MeshDevice& mesh_device, + const std::tuple dims, + const std::tuple mesh_shape) -> std::unique_ptr { + TT_FATAL( + std::get<0>(mesh_shape) <= mesh_device.shape().num_rows && // + std::get<1>(mesh_shape) <= mesh_device.shape().num_cols, + "Device mesh shape does not match the provided mesh shape."); + return concat_2d_mesh_to_tensor_composer( + mesh_device, Concat2dConfig{.row_dim = std::get<0>(dims), .col_dim = std::get<1>(dims)}); + }, + py::arg("mesh_device"), + py::arg("dims"), + py::arg("mesh_shape"), + R"doc( + Create a ConcatMesh2dToTensor composer with the given mesh device and dimensions. + + Args: + mesh_device (MeshDevice): The mesh device to create the composer for. + dims (Tuple[int, int]): The dimensions to create the composer for in (row, column) format. + mesh_shape (Tuple[int, int]): The shape of the 2D mesh as (num_rows, num_cols). + + Returns: + TensorToMesh: The created ConcatMesh2dToTensor composer. 
+ )doc"); module.def( "distribute_tensor", [](const Tensor& tensor, diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp index f87e98baf4b..f31617a7d7b 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp @@ -206,11 +206,11 @@ Tensor aggregate_tensor(const Tensor& tensor, const MeshToTensor& composer) { : composer.compose({tensor}); } -static Shard2dConfig get_shard2d_config(const std::unordered_map& metadata) { +Shard2dConfig get_shard2d_config(const std::unordered_map& metadata) { return Shard2dConfig(std::stoi(metadata.at("row_dim")), std::stoi(metadata.at("col_dim"))); } -static Concat2dConfig get_concat2d_config(const std::unordered_map& metadata) { +Concat2dConfig get_concat2d_config(const std::unordered_map& metadata) { return Concat2dConfig(std::stoi(metadata.at("row_dim")), std::stoi(metadata.at("col_dim"))); } From 54dd2d448d1880ea826c843c687aec02f58b5de4 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Wed, 19 Feb 2025 22:41:08 +0000 Subject: [PATCH 16/76] overload for concatmeshtotensor with meshdevice --- ttnn/cpp/ttnn/distributed/distributed_pybind.cpp | 11 +++++++++++ ttnn/cpp/ttnn/distributed/distributed_tensor.hpp | 14 ++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 25805cb0a2b..ab905ee89af 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -521,10 +521,21 @@ void py_module(py::module& module) { return std::make_unique(dim); }), py::arg("dim")) + .def( + py::init([](MeshDevice mesh_device, int dim) -> std::unique_ptr { + return std::make_unique(mesh_device, dim); + }), + py::arg("mesh_device"), + py::arg("dim")) .def( "compose", [](const ConcatMeshToTensor& self, const std::vector& tensors) { return self.compose(tensors); }, + py::arg("tensors")) + .def( + "compose", + [](const ConcatMeshToTensor& self, const Tensor& tensor) { return self.compose(tensor); }, py::arg("tensors")); + auto py_concat_2d_mesh_to_tensor = static_cast>>( module.attr("ConcatMesh2dToTensor")); diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp index 22c4803b437..6defda1727f 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp @@ -4,6 +4,7 @@ #pragma once +#include "tt-metalium/mesh_device.hpp" #include "ttnn/tensor/tensor.hpp" #include "ttnn/distributed/types.hpp" #include "ttnn/distributed/api.hpp" @@ -164,6 +165,19 @@ class ConcatMeshToTensor : public MeshToTensor { int concat_dim_ = -1; }; +class DeviceConcatMeshToTensor : public ConcatMeshToTensor { +public: + DeviceConcatMeshToTensor(MeshDevice mesh_device, int dim) : mesh_device_(mesh_device), concat_dim_(dim) {} + + Tensor compose(const Tensor& tensor) { + return experimental::xtensor::concat(get_device_tensors(tensor), concat_dim_); + } + +private: + MeshDevice mesh_device_; + int concat_dim_ = -1; +}; + class ConcatMesh2dToTensor : public MeshToTensor { public: ConcatMesh2dToTensor(MeshDevice& mesh_device, const Concat2dConfig& config) : From d0678b35bda2210eb51cd5086e92950f34bde76b Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Thu, 20 Feb 2025 05:08:54 +0000 Subject: [PATCH 17/76] remove extraneous comments --- tests/ttnn/distributed/test_distributed_tensor.py | 6 +++--- 1 file changed, 3 
insertions(+), 3 deletions(-) diff --git a/tests/ttnn/distributed/test_distributed_tensor.py b/tests/ttnn/distributed/test_distributed_tensor.py index b5456cfa0d2..614ffeffd03 100644 --- a/tests/ttnn/distributed/test_distributed_tensor.py +++ b/tests/ttnn/distributed/test_distributed_tensor.py @@ -126,7 +126,7 @@ def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): core_grid = ttnn.CoreGrid(y=1, x=8) # If K < N it's FF1-like test case, else FF2-like test case - shard_dim = (0, 3) if K < N else (3, 0) # None means to replicate along this dim + shard_dim = (0, 3) if K < N else (3, 0) K = K // mesh_shape[1] if K < N else K // mesh_shape[0] N = N // mesh_shape[0] if K < N else N // mesh_shape[1] @@ -173,8 +173,8 @@ def test_concat2d_to_tensor(M, K, N, dtype, mesh_shape, mesh_device): core_grid = ttnn.CoreGrid(y=1, x=8) # If K < N it's FF1-like test case, else FF2-like test case - shard_dim = (0, 3) if K < N else (3, 0) # None means to replicate along this dim - concat_dim = (3, 1) if K < N else (1, 3) # dim 1 for reduce, dim 3 for concatenating fractures + shard_dim = (0, 3) if K < N else (3, 0) + concat_dim = (3, 1) if K < N else (1, 3) K = K // mesh_shape[1] if K < N else K // mesh_shape[0] N = N // mesh_shape[0] if K < N else N // mesh_shape[1] From c2b9bc7076ac5e4b70463a84ad8476b50682eefb Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Thu, 20 Feb 2025 05:47:43 +0000 Subject: [PATCH 18/76] fix deviceconcat errors --- ttnn/cpp/ttnn/distributed/distributed_pybind.cpp | 4 ++-- ttnn/cpp/ttnn/distributed/distributed_tensor.hpp | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index ab905ee89af..ccb754c641c 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -522,7 +522,7 @@ void py_module(py::module& module) { }), py::arg("dim")) .def( - py::init([](MeshDevice mesh_device, int dim) -> std::unique_ptr { + py::init([](MeshDevice& mesh_device, int dim) -> std::unique_ptr { return std::make_unique(mesh_device, dim); }), py::arg("mesh_device"), @@ -533,7 +533,7 @@ void py_module(py::module& module) { py::arg("tensors")) .def( "compose", - [](const ConcatMeshToTensor& self, const Tensor& tensor) { return self.compose(tensor); }, + [](const DeviceConcatMeshToTensor& self, const Tensor& tensor) { return self.compose(tensor); }, py::arg("tensors")); auto py_concat_2d_mesh_to_tensor = diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp index 6defda1727f..3057bbd1cb6 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp @@ -167,14 +167,15 @@ class ConcatMeshToTensor : public MeshToTensor { class DeviceConcatMeshToTensor : public ConcatMeshToTensor { public: - DeviceConcatMeshToTensor(MeshDevice mesh_device, int dim) : mesh_device_(mesh_device), concat_dim_(dim) {} + DeviceConcatMeshToTensor(MeshDevice& mesh_device, int dim) : + ConcatMeshToTensor(dim), mesh_device_(mesh_device), concat_dim_(dim) {} - Tensor compose(const Tensor& tensor) { + Tensor compose(const Tensor& tensor) const { return experimental::xtensor::concat(get_device_tensors(tensor), concat_dim_); } private: - MeshDevice mesh_device_; + MeshDevice& mesh_device_; int concat_dim_ = -1; }; From a452991ace150087fc4e402bc78c352298376dc8 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Thu, 20 Feb 2025 17:16:07 
+0000 Subject: [PATCH 19/76] add back distributed.py for now, clean up class overloads --- .../distributed/test_distributed_tensor.py | 12 +- .../ttnn/distributed/distributed_pybind.cpp | 198 +++++----------- .../ttnn/distributed/distributed_tensor.cpp | 8 + .../ttnn/distributed/distributed_tensor.hpp | 47 +--- ttnn/ttnn/__init__.py | 14 +- ttnn/ttnn/distributed/__init__.py | 9 + ttnn/ttnn/distributed/distributed.py | 222 +++++++++++++++++- 7 files changed, 309 insertions(+), 201 deletions(-) diff --git a/tests/ttnn/distributed/test_distributed_tensor.py b/tests/ttnn/distributed/test_distributed_tensor.py index 614ffeffd03..4d3b593287e 100644 --- a/tests/ttnn/distributed/test_distributed_tensor.py +++ b/tests/ttnn/distributed/test_distributed_tensor.py @@ -21,7 +21,7 @@ def test_replicate_to_tensor_mesh(mesh_device, dtype): torch.manual_seed(1234) - torch_tensor = torch.randn(1, 1, 32, 8192) + torch_tensor = torch.randn(1, 1, 32, 256) to_repl = ttnn.from_torch( torch_tensor, dtype=dtype, @@ -42,7 +42,7 @@ def test_replicate_to_tensor_mesh(mesh_device, dtype): def test_shard_to_tensor_mesh(mesh_device, dtype): torch.manual_seed(1234) - torch_tensor = torch.randn(1, 1, 8192, 32768) + torch_tensor = torch.randn(1, 1, 32, 256) to_shard = ttnn.from_torch( torch_tensor, dtype=dtype, @@ -65,7 +65,7 @@ def test_shard_to_tensor_mesh(mesh_device, dtype): def test_concat_to_tensor(mesh_device, dtype): torch.manual_seed(1234) - torch_tensor = torch.randn(1, 1, 8192, 32768) + torch_tensor = torch.randn(1, 1, 32, 256) to_shard = ttnn.from_torch( torch_tensor, dtype=dtype, @@ -88,7 +88,7 @@ def test_concat_to_tensor(mesh_device, dtype): def test_concat_slice_to_tensor(mesh_device, dtype): torch.manual_seed(1234) - torch_tensor = torch.randn(1, 1, 8192, 32768) + torch_tensor = torch.randn(1, 1, 32, 256) to_shard = ttnn.from_torch( torch_tensor, dtype=dtype, @@ -116,7 +116,7 @@ def test_concat_slice_to_tensor(mesh_device, dtype): ) @pytest.mark.parametrize( "M, K, N", - [pytest.param(32, 8192, 28 * 1024), pytest.param(32, 28 * 1024, 8192)], + [pytest.param(32, 64, 128), pytest.param(32, 128, 64)], ) @pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): @@ -163,7 +163,7 @@ def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): ) @pytest.mark.parametrize( "M, K, N", - [pytest.param(32, 8192, 28 * 1024), pytest.param(32, 28 * 1024, 8192)], + [pytest.param(32, 128, 64), pytest.param(32, 128, 64)], ) @pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) def test_concat2d_to_tensor(M, K, N, dtype, mesh_shape, mesh_device): diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index ccb754c641c..32abbfef1a0 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -62,15 +62,17 @@ struct ConcreteMeshToTensor : MeshToTensor { }; void py_module_types(py::module& module) { - py::class_>(module, "MeshToTensor"); - py::class_>(module, "TensorToMesh"); + py::class_>(module, "CppMeshToTensor"); + py::class_>(module, "CppTensorToMesh"); + py::class_>( - module, "ReplicateTensorToMesh"); - py::class_>(module, "ShardTensorToMesh"); - py::class_>(module, "ShardTensor2dMesh"); - py::class_>(module, "ConcatMeshToTensor"); - py::class_>( - module, "ConcatMesh2dToTensor"); + module, 
"CppReplicateTensorToMesh"); + py::class_>(module, "CppShardTensorToMesh"); + py::class_>( + module, "CppShardTensorTo2dMesh"); + py::class_>(module, "CppConcatMeshToTensor"); + py::class_>( + module, "CppConcat2dMeshToTensor"); py::class_(module, "ReplicateTensor"); py::class_(module, "ShardTensor"); @@ -426,190 +428,114 @@ void py_module(py::module& module) { )doc"); auto py_tensor_to_mesh = static_cast>>( - module.attr("TensorToMesh")); + module.attr("CppTensorToMesh")); py_tensor_to_mesh .def(py::init([]() -> std::unique_ptr { return std::make_unique(); })) .def("map", &TensorToMesh::map) .def("config", &TensorToMesh::config); auto py_replicate_tensor_to_mesh = static_cast>>( - module.attr("ReplicateTensorToMesh")); + module.attr("CppReplicateTensorToMesh")); py_replicate_tensor_to_mesh .def( py::init([](MeshDevice& mesh_device) -> std::unique_ptr { return std::make_unique(ReplicateTensorToMesh(mesh_device.num_devices())); }), py::arg("mesh_device")) - .def( - py::init([](size_t num_devices) -> std::unique_ptr { - return std::make_unique(ReplicateTensorToMesh(num_devices)); - }), - py::arg("num_devices")) .def( "map", [](const ReplicateTensorToMesh& self, const Tensor& tensor) { return self.map(tensor); }, py::arg("tensor")) .def("config", &ReplicateTensorToMesh::config); auto py_shard_tensor_to_mesh = static_cast>>( - module.attr("ShardTensorToMesh")); + module.attr("CppShardTensorToMesh")); py_shard_tensor_to_mesh .def( py::init([](MeshDevice& mesh_device, int dim) -> std::unique_ptr { - return std::make_unique(ShardTensorToMesh(mesh_device, dim)); + return std::make_unique(ShardTensorToMesh(mesh_device.num_devices(), dim)); }), py::arg("mesh_device"), py::arg("dim")) - .def( - py::init([](size_t num_devices, int dim) -> std::unique_ptr { - return std::make_unique(ShardTensorToMesh(num_devices, dim)); - }), - py::arg("num_devices"), - py::arg("dim")) .def( "map", [](const ShardTensorToMesh& self, const Tensor& tensor) { return self.map(tensor); }, py::arg("tensor")) .def("config", &ShardTensorToMesh::config); - auto py_shard_tensor_to_2d_mesh = static_cast>>( - module.attr("ShardTensor2dMesh")); + auto py_shard_tensor_to_2d_mesh = + static_cast>>( + module.attr("CppShardTensorTo2dMesh")); py_shard_tensor_to_2d_mesh .def( py::init( [](MeshDevice& mesh_device, - const MeshShape& mesh_shape, - const Shard2dConfig& config) -> std::unique_ptr { - return std::make_unique(mesh_device, mesh_shape, config); - }), - py::arg("mesh_device"), - py::arg("mesh_shape"), - py::arg("config")) - .def( - py::init( - [](MeshDevice& mesh_device, - const std::tuple dims, - const MeshShape& mesh_shape) -> std::unique_ptr { - return std::make_unique( - mesh_device, - mesh_shape, + const std::tuple mesh_shape, + const std::tuple dims) -> std::unique_ptr { + int shape_rows = std::get<0>(mesh_shape); + int shape_cols = std::get<1>(mesh_shape); + + int config_rows = std::get<0>(dims); + int config_cols = std::get<1>(dims); + TT_FATAL( + config_rows || config_cols, + "Sharding a tensor to 2D mesh requires at least one dimension to shard"); + TT_FATAL( + shape_rows <= mesh_device.shape().num_rows && // + shape_cols <= mesh_device.shape().num_cols, + "Device mesh shape does not match the provided mesh shape."); + + return std::make_unique( + MeshShape{.num_rows = shape_rows, .num_cols = shape_cols}, Shard2dConfig{.row_dim = std::get<0>(dims), .col_dim = std::get<1>(dims)}); }), py::arg("mesh_device"), - py::arg("dims"), - py::arg("mesh_shape")) - .def( - py::init( - [](const MeshShape& mesh_shape, const 
Shard2dConfig& config) -> std::unique_ptr { - return std::make_unique(ShardTensor2dMesh(mesh_shape, config)); - }), py::arg("mesh_shape"), - py::arg("config")) + py::arg("dims")) .def( "map", - [](const ShardTensor2dMesh& self, const Tensor& tensor) { return self.map(tensor); }, + [](const ShardTensorTo2dMesh& self, const Tensor& tensor) { return self.map(tensor); }, py::arg("tensor")) - .def("config", &ShardTensor2dMesh::config); + .def("config", &ShardTensorTo2dMesh::config); auto py_mesh_to_tensor = static_cast>>( - module.attr("MeshToTensor")); + module.attr("CppMeshToTensor")); py_mesh_to_tensor .def(py::init([]() -> std::unique_ptr { return std::make_unique(); })) .def("compose", &MeshToTensor::compose); auto py_concat_mesh_to_tensor = static_cast>>( - module.attr("ConcatMeshToTensor")); + module.attr("CppConcatMeshToTensor")); py_concat_mesh_to_tensor - .def( - py::init([](int dim) -> std::unique_ptr { - return std::make_unique(dim); - }), - py::arg("dim")) .def( py::init([](MeshDevice& mesh_device, int dim) -> std::unique_ptr { - return std::make_unique(mesh_device, dim); + return std::make_unique(dim); }), py::arg("mesh_device"), py::arg("dim")) .def( "compose", [](const ConcatMeshToTensor& self, const std::vector& tensors) { return self.compose(tensors); }, - py::arg("tensors")) - .def( - "compose", - [](const DeviceConcatMeshToTensor& self, const Tensor& tensor) { return self.compose(tensor); }, py::arg("tensors")); auto py_concat_2d_mesh_to_tensor = - static_cast>>( - module.attr("ConcatMesh2dToTensor")); + static_cast>>( + module.attr("CppConcat2dMeshToTensor")); py_concat_2d_mesh_to_tensor .def( py::init( - [](MeshDevice& mesh_device, const Concat2dConfig& config) -> std::unique_ptr { - TT_FATAL( - config.row_dim != config.col_dim, - "Dimensions in 'dims' must be different; got row_dim: {}, col_dim: {}", - config.row_dim, - config.col_dim); - return std::make_unique(mesh_device, config); - }), - py::arg("mesh_device"), - py::arg("config")) - .def( - py::init( - [](MeshDevice& mesh_device, const std::tuple dims) -> std::unique_ptr { + [](MeshDevice& mesh_device, + const std::tuple mesh_shape, + const std::tuple dims) -> std::unique_ptr { int row_dim = std::get<0>(dims); int col_dim = std::get<1>(dims); TT_FATAL( - row_dim != col_dim, - "Dimensions in 'dims' must be different; got row_dim: {}, col_dim: {}", - row_dim, - col_dim); - return std::make_unique( - mesh_device, - Concat2dConfig{ - .row_dim = row_dim, - .col_dim = col_dim, - }); - }), - py::arg("mesh_device"), - py::arg("dims")) - .def( - py::init( - [](MeshDevice& mesh_device, - const Concat2dConfig& config, - MeshShape& mesh_shape) -> std::unique_ptr { - TT_FATAL( - config.row_dim != config.col_dim, - "Dimensions in 'dims' must be different; got row_dim: {}, col_dim: {}", - config.row_dim, - config.col_dim); - TT_FATAL( - mesh_shape.num_rows <= mesh_device.shape().num_rows && // - mesh_shape.num_cols <= mesh_device.shape().num_cols, + std::get<0>(mesh_shape) <= mesh_device.shape().num_rows && // + std::get<1>(mesh_shape) <= mesh_device.shape().num_cols, "Device mesh shape does not match the provided mesh shape."); - return std::make_unique(mesh_device, config); - }), - py::arg("mesh_device"), - py::arg("config"), - py::arg("mesh_shape")) - .def( - py::init( - [](MeshDevice& mesh_device, - const std::tuple dims, - MeshShape& mesh_shape) -> std::unique_ptr { - int row_dim = std::get<0>(dims); - int col_dim = std::get<1>(dims); - TT_FATAL( row_dim != col_dim, "Dimensions in 'dims' must be different; got row_dim: {}, 
col_dim: {}", row_dim, col_dim); - TT_FATAL( - mesh_shape.num_rows <= mesh_device.shape().num_rows && // - mesh_shape.num_cols <= mesh_device.shape().num_cols, - "Device mesh shape does not match the provided mesh shape."); - - return std::make_unique( + return std::make_unique( mesh_device, Concat2dConfig{ .row_dim = row_dim, @@ -617,11 +543,11 @@ void py_module(py::module& module) { }); }), py::arg("mesh_device"), - py::arg("dims"), - py::arg("mesh_shape")) + py::arg("Mesh_shape"), + py::arg("dims")) .def( "compose", - [](const ConcatMesh2dToTensor& self, const std::vector& tensors) -> Tensor { + [](const Concat2dMeshToTensor& self, const std::vector& tensors) -> Tensor { return self.compose(tensors); }, py::arg("tensors")); @@ -798,29 +724,11 @@ void py_module(py::module& module) { }, py::arg("mesh_device"), py::arg("config")); - module.def( - "concat_2d_mesh_to_tensor_composer", - [](MeshDevice& mesh_device, const std::tuple dims) -> std::unique_ptr { - return concat_2d_mesh_to_tensor_composer( - mesh_device, Concat2dConfig{.row_dim = std::get<0>(dims), .col_dim = std::get<1>(dims)}); - }, - py::arg("mesh_device"), - py::arg("dims"), - R"doc( - Create a ConcatMesh2dToTensor composer with the given mesh device and dimensions. - - Args: - mesh_device (MeshDevice): The mesh device to create the composer for. - dims (Tuple[int, int]): The dimensions to create the composer for in (row, column) format. - - Returns: - TensorToMesh: The created ConcatMesh2dToTensor composer. - )doc"); module.def( "concat_2d_mesh_to_tensor_composer", [](MeshDevice& mesh_device, - const std::tuple dims, - const std::tuple mesh_shape) -> std::unique_ptr { + const std::tuple mesh_shape, + const std::tuple dims) -> std::unique_ptr { TT_FATAL( std::get<0>(mesh_shape) <= mesh_device.shape().num_rows && // std::get<1>(mesh_shape) <= mesh_device.shape().num_cols, diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp index f31617a7d7b..37556c7324d 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp @@ -161,11 +161,15 @@ std::unique_ptr shard_tensor_to_2d_mesh_mapper( mesh_shape[0] <= mesh_device.shape()[0] && // mesh_shape[1] <= mesh_device.shape()[1], "Device mesh shape does not match the provided mesh shape."); +<<<<<<< HEAD <<<<<<< HEAD return std::make_unique(mesh_shape[0], mesh_shape[1], config); ======= return std::make_unique(mesh_shape, config); >>>>>>> fix naming errors, add tests, add imports - TODO, fix weird aliasing error with meshdevice vs ttnn.multidevice.meshdevice +======= + return std::make_unique(mesh_shape, config); +>>>>>>> add back distributed.py for now, clean up class overloads } std::unique_ptr concat_mesh_to_tensor_composer(int dim) { @@ -178,12 +182,16 @@ std::unique_ptr concat_2d_mesh_to_tensor_composer(MeshDevice& mesh "Dimensions in 'dims' must be different; got row_dim: {}, col_dim: {}", config.row_dim, config.col_dim); +<<<<<<< HEAD <<<<<<< HEAD TT_FATAL(mesh_device.shape().dims() == 2, "Mesh device is not configured as a 2D mesh: {}", mesh_device.shape()); return std::make_unique(mesh_device.shape()[0], mesh_device.shape()[1], config); ======= return std::make_unique(mesh_device, config); >>>>>>> fix naming errors, add tests, add imports - TODO, fix weird aliasing error with meshdevice vs ttnn.multidevice.meshdevice +======= + return std::make_unique(mesh_device, config); +>>>>>>> add back distributed.py for now, clean up class overloads } Tensor distribute_tensor( 
diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp index 3057bbd1cb6..a381564fe66 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp @@ -45,8 +45,6 @@ class ReplicateTensorToMesh : public TensorToMesh { public: ReplicateTensorToMesh(size_t num_devices) : num_devices_(num_devices) {} - ReplicateTensorToMesh(MeshDevice& mesh_device) : num_devices_(mesh_device.num_devices()) {} - std::vector map(const Tensor& tensor) const override { std::vector tensors; tensors.reserve(num_devices_); @@ -66,8 +64,6 @@ class ShardTensorToMesh : public TensorToMesh { public: ShardTensorToMesh(size_t num_devices, int dim) : num_devices_(num_devices), shard_dim_(dim) {} - ShardTensorToMesh(MeshDevice& mesh_device, int dim) : num_devices_(mesh_device.num_devices()), shard_dim_(dim) {} - std::vector map(const Tensor& tensor) const override { return experimental::xtensor::chunk(tensor, num_devices_, shard_dim_); } @@ -81,25 +77,10 @@ class ShardTensorToMesh : public TensorToMesh { int shard_dim_ = -1; }; -class ShardTensor2dMesh : public TensorToMesh { +class ShardTensorTo2dMesh : public TensorToMesh { public: - ShardTensor2dMesh(MeshDevice& mesh_device, const MeshShape& mesh_shape, const Shard2dConfig& config) : - mesh_shape_(mesh_shape), config_(config) { - TT_FATAL( - config.row_dim.has_value() || config.col_dim.has_value(), - "Sharding a tensor to 2D mesh requires at least one dimension to shard"); - TT_FATAL( - mesh_shape.num_rows <= mesh_device.shape().num_rows && // - mesh_shape.num_cols <= mesh_device.shape().num_cols, - "Device mesh shape does not match the provided mesh shape."); - } - - ShardTensor2dMesh(const MeshShape& mesh_shape, const Shard2dConfig& config) : - mesh_shape_(mesh_shape), config_(config) { - TT_FATAL( - config.row_dim.has_value() || config.col_dim.has_value(), - "Sharding a tensor to 2D mesh requires at least one dimension to shard"); - } + ShardTensorTo2dMesh(const MeshShape& mesh_shape, const Shard2dConfig& config) : + mesh_shape_(mesh_shape), config_(config) {} std::vector map(const Tensor& tensor) const override { const auto [rows, cols] = mesh_shape_; @@ -135,7 +116,7 @@ class ShardTensor2dMesh : public TensorToMesh { TT_FATAL( static_cast(tensor_shards.size()) == rows * cols, - "ShardTensor2dMesh: Sharding failed. Number of shards should match the product of the mesh " + "ShardTensorTo2dMesh: Sharding failed. Number of shards should match the product of the mesh " "dimensions. 
Size: {}, rows: {}, cols: {}", tensor_shards.size(), rows, @@ -145,7 +126,7 @@ class ShardTensor2dMesh : public TensorToMesh { } tt::tt_metal::DistributedTensorConfig config() const override { - return DistributedTensorConfig{ShardTensor2D{ShardMesh{.y = mesh_shape_.num_rows, .x = mesh_shape_.num_cols}}}; + return DistributedTensorConfig{ShardTensor2D{ShardMesh{mesh_shape_.num_rows, mesh_shape_.num_cols}}}; } private: @@ -165,23 +146,9 @@ class ConcatMeshToTensor : public MeshToTensor { int concat_dim_ = -1; }; -class DeviceConcatMeshToTensor : public ConcatMeshToTensor { -public: - DeviceConcatMeshToTensor(MeshDevice& mesh_device, int dim) : - ConcatMeshToTensor(dim), mesh_device_(mesh_device), concat_dim_(dim) {} - - Tensor compose(const Tensor& tensor) const { - return experimental::xtensor::concat(get_device_tensors(tensor), concat_dim_); - } - -private: - MeshDevice& mesh_device_; - int concat_dim_ = -1; -}; - -class ConcatMesh2dToTensor : public MeshToTensor { +class Concat2dMeshToTensor : public MeshToTensor { public: - ConcatMesh2dToTensor(MeshDevice& mesh_device, const Concat2dConfig& config) : + Concat2dMeshToTensor(MeshDevice& mesh_device, const Concat2dConfig& config) : mesh_shape_(mesh_device.shape()), config_(config) {} Tensor compose(const std::vector& tensors) const override { diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index 312eeb61551..18dbbc78cce 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -96,13 +96,13 @@ def manage_config(name, value): from ttnn._ttnn.multi_device import ( MeshDevice, - MeshToTensor, - TensorToMesh, - ReplicateTensorToMesh, - ShardTensorToMesh, - ShardTensor2dMesh, - ConcatMeshToTensor, - ConcatMesh2dToTensor, + # CppMeshToTensor, + # CppTensorToMesh, + # CppReplicateTensorToMesh, + # CppShardTensorToMesh, + # CppShardTensorTo2dMesh, + # CppConcatMeshToTensor, + # CppConcat2dMeshToTensor, ReplicateTensor, ShardTensor, ShardTensor2d, diff --git a/ttnn/ttnn/distributed/__init__.py b/ttnn/ttnn/distributed/__init__.py index c776d4d91f6..4901c6ae8cb 100644 --- a/ttnn/ttnn/distributed/__init__.py +++ b/ttnn/ttnn/distributed/__init__.py @@ -2,8 +2,17 @@ # SPDX-License-Identifier: Apache-2.0 +# TODO: All of the TensorTo and MeshTo classes will be slowly cut out over the next few days from .distributed import ( + MeshDevice, DispatchCoreType, + TensorToMesh, + ShardTensorToMesh, + ShardTensor2dMesh, + ReplicateTensorToMesh, + MeshToTensor, + ConcatMeshToTensor, + ConcatMesh2dToTensor, open_mesh_device, close_mesh_device, get_num_pcie_devices, diff --git a/ttnn/ttnn/distributed/distributed.py b/ttnn/ttnn/distributed/distributed.py index b29089e72f6..fa057bd0051 100644 --- a/ttnn/ttnn/distributed/distributed.py +++ b/ttnn/ttnn/distributed/distributed.py @@ -209,8 +209,224 @@ def synchronize_devices( ttnn._ttnn.device.synchronize_device(devices.get_device(device), queue_id, sub_device_ids) +# TODO: All of the TensorTo and MeshTo classes will be slowly cut out over the next few days +class TensorToMesh: + """ + Defines the mapping of a torch.Tensor to a device mesh: e.g. Shard/Replicate. + You can also "Bring your own TensorToMesh" based on your custom mapping. + """ + + def __init__(self, mesh_device): + self.mesh_device = mesh_device + + def map(self, tensor: "torch.Tensor"): + raise NotImplementedError("Subclasses must implement this method") + + def config(self): + raise NotImplementedError("Subclasses must implement this method") + + +class MeshToTensor: + """ + Defines the inverse operation of TensorToMesh. 
Given a set of per-device + ttnn.Tensor objects (aggregated into a single ttnn.Tensor), this class defines + the mapping back to one or many torch.Tensor objects. + You can also "Bring your own MeshToTensor" based on your custom mapping. + """ + + def compose(self, tensor: ttnn.Tensor): + raise NotImplementedError("Subclasses must implement this method") + + +class ShardTensorToMesh(TensorToMesh): + def __init__(self, mesh_device, dim): + super().__init__(mesh_device) + self.shard_dim = dim + + def map(self, tensor: "torch.Tensor") -> Dict[int, ttnn.Tensor]: + import torch + + sliced_tensors = torch.chunk(tensor, self.mesh_device.get_num_devices(), dim=self.shard_dim) + return list(sliced_tensors) + + def config(self): + return { + "strategy": "shard", + "shard_dim": f"{self.shard_dim}", + } + + +class ShardTensor2dMesh(TensorToMesh): + """ + Shard a tensor across a 2D mesh of devices. + This class implements a strategy for distributing a tensor across a 2D grid of devices, + allowing for efficient parallel processing in distributed computing environments. + """ + + def __init__(self, mesh_device: MeshDevice, mesh_shape: Tuple[int, int], dims: Tuple[Optional[int], Optional[int]]): + """ + Initialize the ShardTensor2dMesh. + Args: + mesh_device: The target device mesh for distributing the tensor. + mesh_shape: The shape of the 2D mesh as (rows, cols). + dims: The dimensions to shard along, specified as (row_dim, col_dim). + The `dims` tuple determines how the tensor is sharded across the 2D mesh: + - row_dim: The dimension to shard across mesh rows (or None for replication). + - col_dim: The dimension to shard across mesh columns (or None for replication). + Examples: + 1. dims=(2, 3) for a tensor of shape (A, B, C, D): + - Shard along dimension 2 (C) across mesh rows + - Shard along dimension 3 (D) across mesh columns + 2. dims=(None, 3): + - Replicate across mesh rows + - Shard along dimension 3 (D) across mesh columns + 3. dims=(None, None): + - Fully replicate the tensor across all devices + """ + super().__init__(mesh_device) + self.mesh_shape: Tuple[int, int] = mesh_shape + self.dims: Tuple[Optional[int], Optional[int]] = dims + + mesh_device_rows, mesh_device_cols = self.mesh_device.shape + if mesh_shape[0] > mesh_device_rows or mesh_shape[1] > mesh_device_cols: + raise ValueError("ShardTensor2dMesh: Device mesh shape does not match the provided mesh shape.") + + def map(self, tensor: "torch.Tensor") -> List["torch.Tensor"]: + """ + Map the input tensor to a list of sharded tensors. + Args: + tensor: The input tensor to be sharded. + Returns: + A list of sharded tensors, one for each device in the mesh. + Raises: + ValueError: If the number of sharding dimensions is not 2. + """ + import torch + + if len(self.dims) != 2: + raise ValueError("ShardTensor2dMesh only supports 2D shard dimensions") + + rows, cols = self.mesh_shape + row_dim, col_dim = self.dims + + # Shard along rows + row_tensors = ( + [tensor.clone() for _ in range(rows)] if row_dim is None else torch.chunk(tensor, rows, dim=row_dim) + ) + + # Shard along columns + if col_dim is None: + return [t.clone() for t in row_tensors for _ in range(cols)] + tensor_shards = [tt for t in row_tensors for tt in torch.chunk(t, cols, dim=col_dim)] + + if len(tensor_shards) != rows * cols: + raise ValueError( + f"ShardTensor2dMesh: Sharding failed. Number of shards should match the product of the mesh dimensions. Got {len(tensor_shards)} shards but expected {rows * cols} ({rows} rows * {cols} cols)." 
+ ) + + return tensor_shards + + def config(self) -> Dict[str, str]: + """ + Provide the configuration of the sharding strategy. + Returns: + A dictionary containing the sharding strategy and dimensions. + """ + return { + "strategy": "shard_2d", + "mesh_shape_y": str(self.mesh_shape[0]), + "mesh_shape_x": str(self.mesh_shape[1]), + } + + +class ConcatMesh2dToTensor(MeshToTensor): + """ + Concatenate tensors from a 2D mesh back into a single tensor. + This class implements the inverse operation of ShardTensor2dMesh, combining + sharded tensors from a 2D device mesh back into a single tensor. + """ + + def __init__(self, mesh_device: MeshDevice, mesh_shape: Tuple[int, int], dims: Tuple[int, int]): + """ + Initialize the ConcatMesh2dToTensor. + Args: + mesh_device: The source device mesh containing the sharded tensors. + mesh_shape: The shape of the 2D mesh as (rows, cols). + dims: A tuple of two integers specifying the dimensions along which to concatenate the tensors. + The first element (row_dim) indicates the dimension for concatenating tensors from different rows. + The second element (col_dim) indicates the dimension for concatenating tensors from different columns. + Both dimensions must be specified and different from each other. + These dimensions correspond to the tensor dimensions, not the mesh dimensions. + For example, if the original tensor was 4D with shape (batch, channel, height, width), + and it was sharded across height and width, dims might be (-2, -1) or (2, 3). + Raises: + ValueError: If either dimension in 'dims' is None or if both dimensions are the same. + """ + self.mesh_device = mesh_device + self.mesh_shape = mesh_shape + self.dims = dims + if self.dims[0] == self.dims[1]: + raise ValueError("Both dimensions in 'dims' must be different") + + def compose(self, tensor: ttnn.Tensor) -> "torch.Tensor": + """ + Compose the sharded tensors back into a single tensor. + Args: + tensor: A ttnn.Tensor object containing the sharded tensors distributed across multiple devices. + Returns: + A single torch.Tensor that combines all the sharded tensors from all devices. + This method first concatenates the shards along the column dimension within each row, + then concatenates the resulting tensors along the row dimension to form the final tensor. 
+ """ + import torch + + device_shards = [ + ttnn.to_torch(tt_input_tensor, mesh_composer=None) for tt_input_tensor in ttnn.get_device_tensors(tensor) + ] + + rows, cols = self.mesh_shape + row_dim, col_dim = self.dims + + # Reshape the list of shards into a 2D list representing the device mesh + mesh_shape = [device_shards[i : i + cols] for i in range(0, len(device_shards), cols)] + + # Concatenate along columns first (within each row) + row_concatenated = [torch.cat(row, dim=col_dim) for row in mesh_shape] + + # Then concatenate the resulting tensors along rows + return torch.cat(row_concatenated, dim=row_dim) + + +class ReplicateTensorToMesh(TensorToMesh): + def __init__(self, mesh_device: MeshDevice): + super().__init__(mesh_device) + + def map(self, tensor: "torch.Tensor"): + return [tensor for i in range(self.mesh_device.get_num_devices())] + + def config(self): + return { + "strategy": "replicate", + "replication_factor": str(self.mesh_device.get_num_devices()), + } + + +class ConcatMeshToTensor(MeshToTensor): + def __init__(self, mesh_device: MeshDevice, dim: int): + self.concat_dim = dim + self.mesh_device = mesh_device + + def compose(self, tensor: ttnn.Tensor) -> "torch.Tensor": + import torch + + device_shards_converted_to_torch = [ + ttnn.to_torch(tt_input_tensor, mesh_composer=None) for tt_input_tensor in ttnn.get_device_tensors(tensor) + ] + return torch.cat(device_shards_converted_to_torch, dim=self.concat_dim) + + @contextlib.contextmanager -def distribute(default: Union[ttnn.TensorToMesh, ttnn.MeshToTensor]): +def distribute(default: Union[TensorToMesh, MeshToTensor]): """ Context manager to temporarily modify the behavior of ttnn.from_torch and ttnn.to_torch to use the specified mesh_mapper or mesh_composer for tensor distribution and composition to/from MeshDevice. 
@@ -233,9 +449,9 @@ def distribute(default: Union[ttnn.TensorToMesh, ttnn.MeshToTensor]): _original_from_torch = ttnn.from_torch try: - if isinstance(default, ttnn.TensorToMesh): + if isinstance(default, TensorToMesh): ttnn.from_torch = functools.partial(_original_from_torch, mesh_mapper=default) - elif isinstance(default, ttnn.MeshToTensor): + elif isinstance(default, MeshToTensor): ttnn.to_torch = functools.partial(_original_to_torch, mesh_composer=default) else: raise ValueError("Argument must be an instance of either TensorToMesh or MeshToTensor.") From 24b703c429b358b440b50a84f37748176bf49945 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Thu, 20 Feb 2025 17:25:24 +0000 Subject: [PATCH 20/76] remove unused import --- ttnn/cpp/ttnn/distributed/distributed_pybind.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 32abbfef1a0..bfce3a3f6c1 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -4,7 +4,6 @@ #include "ttnn/distributed/distributed_pybind.hpp" #include -#include <<<<<<< HEAD #include From 795e2b140bfc814938d3761028b33ba27bbd0b50 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Thu, 20 Feb 2025 20:31:53 +0000 Subject: [PATCH 21/76] rearrange from_torch.py, start migrating cpp classes and testing integration --- .../distributed/test_distributed_tensor.py | 30 +++++++++++- ttnn/ttnn/__init__.py | 14 +++--- ttnn/ttnn/operations/core.py | 47 ++++++++++--------- 3 files changed, 61 insertions(+), 30 deletions(-) diff --git a/tests/ttnn/distributed/test_distributed_tensor.py b/tests/ttnn/distributed/test_distributed_tensor.py index 4d3b593287e..b31f411fa27 100644 --- a/tests/ttnn/distributed/test_distributed_tensor.py +++ b/tests/ttnn/distributed/test_distributed_tensor.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
# SPDX-License-Identifier: Apache-2.0 @@ -10,6 +10,34 @@ from models.utility_functions import nearest_32 +@pytest.mark.parametrize( + "mesh_device", + [ + 32, + ], + indirect=True, +) +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +def test_direct_replicate_to_tensor_mesh(mesh_device, dtype): + torch.manual_seed(1234) + + torch_tensor = torch.randn(1, 1, 32, 256) + to_repl = ttnn.from_torch( + torch_tensor, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + device=mesh_device, + ) + + mapper = ttnn.CppReplicateTensorToMesh(mesh_device) + replicated_tensors = ttnn.from_torch(to_repl, mapper, mesh_device) + out_tensors = ttnn.get_device_tensors(replicated_tensors) + + out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensors[0]), torch_tensor, pcc=0.99) + logger.info(f"PCC value: {out_pcc}") + assert out_pass + + @pytest.mark.parametrize( "mesh_device", [ diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index 18dbbc78cce..dce198ccc88 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -96,13 +96,13 @@ def manage_config(name, value): from ttnn._ttnn.multi_device import ( MeshDevice, - # CppMeshToTensor, - # CppTensorToMesh, - # CppReplicateTensorToMesh, - # CppShardTensorToMesh, - # CppShardTensorTo2dMesh, - # CppConcatMeshToTensor, - # CppConcat2dMeshToTensor, + CppMeshToTensor, + CppTensorToMesh, + CppReplicateTensorToMesh, + CppShardTensorToMesh, + CppShardTensorTo2dMesh, + CppConcatMeshToTensor, + CppConcat2dMeshToTensor, ReplicateTensor, ShardTensor, ShardTensor2d, diff --git a/ttnn/ttnn/operations/core.py b/ttnn/ttnn/operations/core.py index 409480605bb..1529bb328e4 100644 --- a/ttnn/ttnn/operations/core.py +++ b/ttnn/ttnn/operations/core.py @@ -156,7 +156,7 @@ def from_torch( layout: Optional[ttnn.Layout] = ttnn.ROW_MAJOR_LAYOUT, device: Optional[ttnn.Device] = None, memory_config: Optional[ttnn.MemoryConfig] = None, - mesh_mapper: Optional[ttnn.TensorToMesh] = None, + mesh_mapper: Optional[Union[ttnn.TensorToMesh, ttnn.CppTensorToMesh]] = None, cq_id: Optional[int] = ttnn.DefaultQueueId, ) -> ttnn.Tensor: """ @@ -194,8 +194,17 @@ def from_torch( if memory_config.shard_spec.mode == ttnn.ShardMode.LOGICAL: return ttnn.Tensor(tensor, dtype, device, layout, memory_config, tile) + if memory_config is not None: + if device is None: + raise RuntimeError("ttnn.from_torch: device must be specified when memory_config is specified") + + if pad_value is not None: + if layout != ttnn.TILE_LAYOUT: + raise RuntimeError("ttnn.from_torch: layout must be TILE_LAYOUT when pad_value is specified") + logical_shape = None padded_shape = None + if dtype == ttnn.bfloat8_b or dtype == ttnn.bfloat4_b: if layout != ttnn.TILE_LAYOUT: raise RuntimeError("ttnn.from_torch: bfloat8_b/bfloat4_b requires TILE_LAYOUT!") @@ -204,27 +213,18 @@ def from_torch( logical_shape = tensor.shape padded_shape = tensor.padded_shape tensor = tensor.reshape(tensor.padded_shape) - tensor = ttnn.to_torch(tensor) - - if memory_config is not None: - if device is None: - raise RuntimeError("ttnn.from_torch: device must be specified when memory_config is specified") - - if pad_value is not None: - if layout != ttnn.TILE_LAYOUT: - raise RuntimeError("ttnn.from_torch: layout must be TILE_LAYOUT when pad_value is specified") + else: + tensor = ttnn.Tensor(tensor, dtype) if mesh_mapper: - shards = mesh_mapper.map(tensor) - if tile is not None: - tensor = ttnn.Tensor(shards, dtype, mesh_mapper.config(), tile) - else: - tensor = ttnn.Tensor(shards, dtype, 
mesh_mapper.config()) - else: - if tile is not None: - tensor = ttnn.Tensor(tensor, dtype, {}, tile) + if isinstance(mesh_mapper, ttnn.MeshToTensor): + shards = mesh_mapper.map(ttnn.to_torch(tensor)) else: - tensor = ttnn.Tensor(tensor, dtype) + shards = mesh_mapper.map(tensor) + tensor = ttnn.Tensor(shards, dtype, mesh_mapper.config()) + + if tile is not None: + tensor = ttnn.Tensor(tensor, dtype, {}, tile) if layout is not None and not (dtype == ttnn.bfloat8_b or dtype == ttnn.bfloat4_b): if pad_value is not None: @@ -269,7 +269,7 @@ def to_torch( dtype: Optional[torch.dtype] = None, *, torch_rank: Optional[int] = None, - mesh_composer: Optional[ttnn.MeshToTensor] = None, + mesh_composer: Optional[Union[ttnn.MeshToTensor, ttnn.CppMeshToTensor]] = None, device: Optional[ttnn.Device] = None, cq_id: Optional[int] = ttnn.DefaultQueueId, ) -> "torch.Tensor": @@ -302,7 +302,10 @@ def to_torch( tensor = ttnn.from_device(tensor, cq_id=cq_id) if mesh_composer: - return mesh_composer.compose(tensor) + if isinstance(mesh_composer, ttnn.MeshToTensor): + return mesh_composer.compose(tensor) + else: + return mesh_composer.compose(ttnn.get_device_tensors(tensor)) if tensor.storage_type() == ttnn.DEVICE_STORAGE_TYPE: raise RuntimeError("ttnn.Tensor cannot be on device when converting to torch.Tensor!") @@ -520,7 +523,7 @@ def as_tensor( memory_config: Optional[ttnn.MemoryConfig] = None, cache_file_name: Optional[Union[str, pathlib.Path]] = None, preprocess: Optional[Callable[[ttnn.Tensor], ttnn.Tensor]] = None, - mesh_mapper: Optional[ttnn.TensorToMesh] = None, + mesh_mapper: Optional[Union[ttnn.TensorToMesh, ttnn.CppTensorToMesh]] = None, use_device_tilizer: bool = False, ) -> ttnn.Tensor: """ From 5c160a9bdda0f2ef16a8e42fe83a74fa731a9e86 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Fri, 7 Feb 2025 18:53:04 +0000 Subject: [PATCH 22/76] expose classes to python --- .../ttnn/distributed/distributed_pybind.cpp | 115 ++++++++++++++++++ 1 file changed, 115 insertions(+) diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index bfce3a3f6c1..d4349aea6e0 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -3,18 +3,28 @@ // SPDX-License-Identifier: Apache-2.0 #include "ttnn/distributed/distributed_pybind.hpp" +<<<<<<< HEAD #include <<<<<<< HEAD #include +======= +#include +#include +#include +>>>>>>> expose classes to python ======= #include >>>>>>> fix rebase #include +<<<<<<< HEAD #include "tt-metalium/mesh_coord.hpp" #include "distributed_tensor.hpp" #include "tt-metalium/assert.hpp" +======= +#include "distributed_tensor.hpp" +>>>>>>> expose classes to python #include "ttnn/distributed/api.hpp" <<<<<<< HEAD #include "ttnn/distributed/types.hpp" @@ -61,6 +71,7 @@ struct ConcreteMeshToTensor : MeshToTensor { }; void py_module_types(py::module& module) { +<<<<<<< HEAD py::class_>(module, "CppMeshToTensor"); py::class_>(module, "CppTensorToMesh"); @@ -82,6 +93,15 @@ void py_module_types(py::module& module) { py::class_(module, "Shard2dConfig"); py::class_(module, "Concat2dConfig"); +======= + py::class_>(module, "MeshToTensor"); + py::class_>(module, "TensorToMesh"); + py::class_(module, "TensorToMesh"); + py::class_(module, "ShardTensorToMesh"); + py::class_(module, "ShardTensorTo2dMesh"); + py::class_(module, "ConcatMeshToTensor"); + py::class_(module, "Concat2dMeshToTensor"); +>>>>>>> expose classes to python py::class_>(module, "MeshDevice"); py::class_(module, 
"MeshSubDeviceManagerId"); @@ -426,6 +446,7 @@ void py_module(py::module& module) { back to all SubDevice IDs. )doc"); +<<<<<<< HEAD auto py_tensor_to_mesh = static_cast>>( module.attr("CppTensorToMesh")); py_tensor_to_mesh @@ -550,6 +571,95 @@ void py_module(py::module& module) { return self.compose(tensors); }, py::arg("tensors")); +======= + auto py_tensor_to_mesh = static_cast>>(module.attr("TensorToMesh")); + py_tensor_to_mesh + .def(py::init<>(MeshDevice & mesh_device), + py::kw_only(), + py::arg("mesh_device")) + .def("map", &TensorToMesh::map) + .def("config", &TensorToMesh::config); + + auto py_replicate_tensor_to_mesh = static_cast>( + module.attr("ReplicateTensorToMesh")); + py_replicate_tensor_to_mesh + .def(py::init<>(MeshDevice & mesh_device) { + return replicate_tensor_to_mesh_mapper(mesh_device); + }, + py::kw_only(), + py::arg("mesh_device")) + .def(py::init<>() + py::kw_only()) + .def("map",[](self, const Tensor& tensor) { + return self.map(tensor); + }, + py::arg("tensor") + .def("config", &ReplicateTensorToMesh::config); + + auto py_shard_tensor_to_mesh = static_cast>( + module.attr("ShardTensorToMesh")); + py_shard_tensor_to_mesh + .def(py::init<>(MeshDevice & mesh_device, int dim) { + return shard_tensor_to_mesh_mapper(mesh_device, dim); + }, + py::kw_only(), + py::arg("mesh_device"), + py::arg("dim")) + .def(py::init<>() + py::kw_only()) + .def("map",[](self, const Tensor& tensor) { + return self.map(tensor); + }, + py::arg("tensor")) + .def("config", &ShardTensorToMesh::config); + + auto py_shard_tensor_to_2d_mesh = static_cast>(module.attr("ShardTensorTo2dMesh")); + py_shard_tensor_to_2d_mesh + .def(py::init<>(MeshDevice & mesh_device, const MeshShape& mesh_shape, const Shard2dConfig& config) { + return shard_tensor_to_2d_mesh_mapper(mesh_device, mesh_shape, config); + }, + py::kw_only(), + py::arg("mesh_device"), + py::arg("mesh_shape"), + py::arg("config")) + .def(py::init<>() + py::kw_only()) + .def("map",[](self, const Tensor& tensor) { + return self.map(tensor); + }, + py::arg("tensor")) + .def("config", &ShardTensorTo2dMesh::config); + + auto py_mesh_to_tensor = static_cast>>(module.attr("MeshToTensor")); + py_mesh_to_tensor + .def(py::init<>) + .def("compose", &MeshToTensor::compose); + + auto py_concat_mesh_to_tensor = static_cast>(module.attr("ConcatMeshToTensor")); + py_concat_mesh_to_tensor + .def(py::init<>(int dim) { + return concat_mesh_to_tensor_composer(dim); + }, + py::kw_only(), + py::arg("dim")) + .def("compose",[](self, const std::vector& tensors) { + return self.compose(tensors); + }, + py::arg("tensors")); + + auto py_concat_2d_mesh_to_tensor = static_cast>(module.attr("Concat2dMeshToTensor")); + py_concat_2d_mesh_to_tensor + .def(py::init<>(MeshDevice & mesh_device, const Concat2dConfig& config) { + return concat_2d_mesh_to_tensor_composer(mesh_device, config); + }, + py::kw_only(), + py::arg("mesh_device"), + py::arg("config")) + .def("compose",[](self, const std::vector& tensors) { + return self.compose(tensors); + }, + .py::arg("tensors")); +>>>>>>> expose classes to python module.def( "open_mesh_device", @@ -664,6 +774,7 @@ void py_module(py::module& module) { Tensor: The shard of the tensor corresponding to the device. 
)doc"); module.def("get_device_tensors", &get_device_tensors, py::arg("tensor"), py::kw_only()); +<<<<<<< HEAD // TODO: Add rdocs module.def( "replicate_tensor_to_mesh_mapper", @@ -773,11 +884,15 @@ void py_module(py::module& module) { }, py::arg("tensor"), py::arg("composer")); +======= + //TODO: overload this method to enable selection of a subset of shards with a config or something before passing to aggregate +>>>>>>> expose classes to python module.def( "aggregate_as_tensor", [](const std::vector& tensors) -> Tensor { return aggregate_as_tensor(tensors, AllGatherTensor{}); }, py::arg("tensors"), py::kw_only()); + module.def("get_t3k_physical_device_ids_ring", &get_t3k_physical_device_ids_ring); } From 5db773538cdc3a235c4e5623eee35d4e392320b6 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Sat, 8 Feb 2025 00:23:07 +0000 Subject: [PATCH 23/76] one type error left --- .../ttnn/distributed/distributed_pybind.cpp | 191 ++++++++++++------ 1 file changed, 133 insertions(+), 58 deletions(-) diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index d4349aea6e0..a5de4a9ec1b 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -4,11 +4,15 @@ #include "ttnn/distributed/distributed_pybind.hpp" <<<<<<< HEAD +<<<<<<< HEAD #include <<<<<<< HEAD #include ======= +======= +#include +>>>>>>> one type error left #include #include #include @@ -24,6 +28,7 @@ #include "tt-metalium/assert.hpp" ======= #include "distributed_tensor.hpp" +<<<<<<< HEAD >>>>>>> expose classes to python #include "ttnn/distributed/api.hpp" <<<<<<< HEAD @@ -37,6 +42,12 @@ >>>>>>> one type error left ======= >>>>>>> clean up imports, fix test cases and change them to use mapper/composer functions, fix storage error by using from_device instead of cpu +======= +#include "distributed_tensor.cpp" +#include "ttnn/distributed/api.hpp" +#include "ttnn/distributed/distributed_tensor_config.hpp" +#include "ttnn/tensor/tensor_utils.hpp" +>>>>>>> one type error left #include "ttnn/tensor/tensor.hpp" #include @@ -71,6 +82,7 @@ struct ConcreteMeshToTensor : MeshToTensor { }; void py_module_types(py::module& module) { +<<<<<<< HEAD <<<<<<< HEAD py::class_>(module, "CppMeshToTensor"); py::class_>(module, "CppTensorToMesh"); @@ -102,6 +114,17 @@ void py_module_types(py::module& module) { py::class_(module, "ConcatMeshToTensor"); py::class_(module, "Concat2dMeshToTensor"); >>>>>>> expose classes to python +======= + py::class_>(module, "MeshToTensor"); + py::class_>(module, "TensorToMesh"); + py::class_>( + module, "ReplicateTensorToMesh"); + py::class_>(module, "ShardTensorToMesh"); + py::class_>(module, "ShardTensorTo2dMesh"); + py::class_>(module, "ConcatMeshToTensor"); + py::class_>( + module, "Concat2dMeshToTensor"); +>>>>>>> one type error left py::class_>(module, "MeshDevice"); py::class_(module, "MeshSubDeviceManagerId"); @@ -446,6 +469,7 @@ void py_module(py::module& module) { back to all SubDevice IDs. 
)doc"); +<<<<<<< HEAD <<<<<<< HEAD auto py_tensor_to_mesh = static_cast>>( module.attr("CppTensorToMesh")); @@ -573,82 +597,114 @@ void py_module(py::module& module) { py::arg("tensors")); ======= auto py_tensor_to_mesh = static_cast>>(module.attr("TensorToMesh")); +======= + auto py_tensor_to_mesh = + static_cast>>(module.attr("TensorToMesh")); +>>>>>>> one type error left py_tensor_to_mesh - .def(py::init<>(MeshDevice & mesh_device), - py::kw_only(), - py::arg("mesh_device")) + .def(py::init([]() -> std::unique_ptr { return std::make_unique(); })) .def("map", &TensorToMesh::map) .def("config", &TensorToMesh::config); - auto py_replicate_tensor_to_mesh = static_cast>( - module.attr("ReplicateTensorToMesh")); + auto py_replicate_tensor_to_mesh = + static_cast>>( + module.attr("ReplicateTensorToMesh")); + py_replicate_tensor_to_mesh - .def(py::init<>(MeshDevice & mesh_device) { - return replicate_tensor_to_mesh_mapper(mesh_device); - }, - py::kw_only(), - py::arg("mesh_device")) - .def(py::init<>() - py::kw_only()) - .def("map",[](self, const Tensor& tensor) { - return self.map(tensor); - }, - py::arg("tensor") + .def( + py::init([](MeshDevice& mesh_device) -> std::unique_ptr { + return ttnn::distributed::replicate_tensor_to_mesh_mapper(mesh_device); + }), + py::kw_only(), + py::arg("mesh_device")) + .def( + py::init([](size_t num_devices) -> std::unique_ptr { + return std::make_unique(ReplicateTensorToMesh(num_devices)); + }), + py::kw_only(), + py::arg("num_devices")) + .def( + "map", + [](const ReplicateTensorToMesh& self, const Tensor& tensor) { return self.map(tensor); }, + py::arg("tensor")) .def("config", &ReplicateTensorToMesh::config); - auto py_shard_tensor_to_mesh = static_cast>( - module.attr("ShardTensorToMesh")); + auto py_shard_tensor_to_mesh = static_cast>>( + module.attr("ShardTensorToMesh")); py_shard_tensor_to_mesh - .def(py::init<>(MeshDevice & mesh_device, int dim) { - return shard_tensor_to_mesh_mapper(mesh_device, dim); - }, - py::kw_only(), - py::arg("mesh_device"), - py::arg("dim")) - .def(py::init<>() - py::kw_only()) - .def("map",[](self, const Tensor& tensor) { - return self.map(tensor); - }, - py::arg("tensor")) + .def( + py::init([](MeshDevice& mesh_device, int dim) -> std::unique_ptr { + return ttnn::distributed::shard_tensor_to_mesh_mapper(mesh_device, dim); + }), + py::kw_only(), + py::arg("mesh_device"), + py::arg("dim")) + .def( + py::init([](size_t num_devices, int dim) -> std::unique_ptr { + return std::make_unique(ShardTensorToMesh(num_devices, dim)); + }), + py::kw_only(), + py::arg("num_devices"), + py::arg("dim")) + .def( + "map", + [](const ShardTensorToMesh& self, const Tensor& tensor) { return self.map(tensor); }, + py::arg("tensor")) .def("config", &ShardTensorToMesh::config); - auto py_shard_tensor_to_2d_mesh = static_cast>(module.attr("ShardTensorTo2dMesh")); + auto py_shard_tensor_to_2d_mesh = + static_cast>>( + module.attr("ShardTensorTo2dMesh")); py_shard_tensor_to_2d_mesh - .def(py::init<>(MeshDevice & mesh_device, const MeshShape& mesh_shape, const Shard2dConfig& config) { - return shard_tensor_to_2d_mesh_mapper(mesh_device, mesh_shape, config); - }, - py::kw_only(), - py::arg("mesh_device"), - py::arg("mesh_shape"), - py::arg("config")) - .def(py::init<>() - py::kw_only()) - .def("map",[](self, const Tensor& tensor) { - return self.map(tensor); - }, - py::arg("tensor")) + .def( + py::init( + [](MeshDevice& mesh_device, + const MeshShape& mesh_shape, + const Shard2dConfig& config) -> std::unique_ptr { + return 
ttnn::distributed::shard_tensor_to_2d_mesh_mapper(mesh_device, mesh_shape, config); + }), + py::kw_only(), + py::arg("mesh_device"), + py::arg("mesh_shape"), + py::arg("config")) + .def( + py::init([](const MeshShape& mesh_shape, const Shard2dConfig& config) -> std::unique_ptr { + return std::make_unique(ShardTensorTo2dMesh(mesh_shape, config)); + }), + py::kw_only(), + py::arg("mesh_shape"), + py::arg("config")) + .def( + "map", + [](const ShardTensorTo2dMesh& self, const Tensor& tensor) { return self.map(tensor); }, + py::arg("tensor")) .def("config", &ShardTensorTo2dMesh::config); - auto py_mesh_to_tensor = static_cast>>(module.attr("MeshToTensor")); + auto py_mesh_to_tensor = + static_cast>>(module.attr("MeshToTensor")); py_mesh_to_tensor - .def(py::init<>) + .def(py::init([]() -> std::unique_ptr { return std::make_unique(); })) .def("compose", &MeshToTensor::compose); - auto py_concat_mesh_to_tensor = static_cast>(module.attr("ConcatMeshToTensor")); + auto py_concat_mesh_to_tensor = static_cast>>( + module.attr("ConcatMeshToTensor")); py_concat_mesh_to_tensor - .def(py::init<>(int dim) { - return concat_mesh_to_tensor_composer(dim); - }, - py::kw_only(), - py::arg("dim")) - .def("compose",[](self, const std::vector& tensors) { - return self.compose(tensors); - }, - py::arg("tensors")); + .def( + py::init([](int dim) -> std::unique_ptr { + return ttnn::distributed::concat_mesh_to_tensor_composer(dim); + }), + py::kw_only(), + py::arg("dim")) + .def( + "compose", + [](const ConcatMeshToTensor& self, const std::vector& tensors) { return self.compose(tensors); }, + py::arg("tensors")); - auto py_concat_2d_mesh_to_tensor = static_cast>(module.attr("Concat2dMeshToTensor")); + auto py_concat_2d_mesh_to_tensor = + static_cast>>( + module.attr("Concat2dMeshToTensor")); py_concat_2d_mesh_to_tensor +<<<<<<< HEAD .def(py::init<>(MeshDevice & mesh_device, const Concat2dConfig& config) { return concat_2d_mesh_to_tensor_composer(mesh_device, config); }, @@ -660,6 +716,21 @@ void py_module(py::module& module) { }, .py::arg("tensors")); >>>>>>> expose classes to python +======= + .def( + py::init([](MeshDevice& mesh_device, const Concat2dConfig& config) -> std::unique_ptr { + return ttnn::distributed::concat_2d_mesh_to_tensor_composer(mesh_device, config); + }), + py::kw_only(), + py::arg("mesh_device"), + py::arg("config")) + .def( + "compose", + [](Concat2dMeshToTensor self, const std::vector& tensors) -> Tensor { + return self.compose(tensors); + }, + py::arg("tensors")); +>>>>>>> one type error left module.def( "open_mesh_device", @@ -774,6 +845,7 @@ void py_module(py::module& module) { Tensor: The shard of the tensor corresponding to the device. 
)doc"); module.def("get_device_tensors", &get_device_tensors, py::arg("tensor"), py::kw_only()); +<<<<<<< HEAD <<<<<<< HEAD // TODO: Add rdocs module.def( @@ -887,12 +959,15 @@ void py_module(py::module& module) { ======= //TODO: overload this method to enable selection of a subset of shards with a config or something before passing to aggregate >>>>>>> expose classes to python +======= + // TODO: overload this method to enable selection of a subset of shards with a config or something before passing to + // aggregate +>>>>>>> one type error left module.def( "aggregate_as_tensor", [](const std::vector& tensors) -> Tensor { return aggregate_as_tensor(tensors, AllGatherTensor{}); }, py::arg("tensors"), py::kw_only()); - module.def("get_t3k_physical_device_ids_ring", &get_t3k_physical_device_ids_ring); } From a21afeb86bb0aea49e5391a83aa88fb0ebc7dbcf Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Mon, 10 Feb 2025 17:54:11 +0000 Subject: [PATCH 24/76] move class definitions from from distributed_tensor.cpp to.hpp so they can be exposed to the pybind.cpp; add dummy void methods in .cpp to satisfy linker; add new constructors and factory methods to fix type errors --- .../ttnn/distributed/distributed_pybind.cpp | 60 +++++++++++++------ .../ttnn/distributed/distributed_tensor.hpp | 22 +++++++ 2 files changed, 64 insertions(+), 18 deletions(-) diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index a5de4a9ec1b..49198ddc3af 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -29,6 +29,7 @@ ======= #include "distributed_tensor.hpp" <<<<<<< HEAD +<<<<<<< HEAD >>>>>>> expose classes to python #include "ttnn/distributed/api.hpp" <<<<<<< HEAD @@ -44,6 +45,8 @@ >>>>>>> clean up imports, fix test cases and change them to use mapper/composer functions, fix storage error by using from_device instead of cpu ======= #include "distributed_tensor.cpp" +======= +>>>>>>> move class definitions from from distributed_tensor.cpp to.hpp so they can be exposed to the pybind.cpp; add dummy void methods in .cpp to satisfy linker; add new constructors and factory methods to fix type errors #include "ttnn/distributed/api.hpp" #include "ttnn/distributed/distributed_tensor_config.hpp" #include "ttnn/tensor/tensor_utils.hpp" @@ -469,6 +472,7 @@ void py_module(py::module& module) { back to all SubDevice IDs. 
)doc"); +<<<<<<< HEAD <<<<<<< HEAD <<<<<<< HEAD auto py_tensor_to_mesh = static_cast>>( @@ -601,6 +605,10 @@ void py_module(py::module& module) { auto py_tensor_to_mesh = static_cast>>(module.attr("TensorToMesh")); >>>>>>> one type error left +======= + auto py_tensor_to_mesh = static_cast>>( + module.attr("TensorToMesh")); +>>>>>>> move class definitions from from distributed_tensor.cpp to.hpp so they can be exposed to the pybind.cpp; add dummy void methods in .cpp to satisfy linker; add new constructors and factory methods to fix type errors py_tensor_to_mesh .def(py::init([]() -> std::unique_ptr { return std::make_unique(); })) .def("map", &TensorToMesh::map) @@ -612,13 +620,13 @@ void py_module(py::module& module) { py_replicate_tensor_to_mesh .def( - py::init([](MeshDevice& mesh_device) -> std::unique_ptr { - return ttnn::distributed::replicate_tensor_to_mesh_mapper(mesh_device); + py::init([](MeshDevice& mesh_device) -> std::unique_ptr { + return std::make_unique(ReplicateTensorToMesh(mesh_device.num_devices())); }), py::kw_only(), py::arg("mesh_device")) .def( - py::init([](size_t num_devices) -> std::unique_ptr { + py::init([](size_t num_devices) -> std::unique_ptr { return std::make_unique(ReplicateTensorToMesh(num_devices)); }), py::kw_only(), @@ -633,14 +641,14 @@ void py_module(py::module& module) { module.attr("ShardTensorToMesh")); py_shard_tensor_to_mesh .def( - py::init([](MeshDevice& mesh_device, int dim) -> std::unique_ptr { - return ttnn::distributed::shard_tensor_to_mesh_mapper(mesh_device, dim); + py::init([](MeshDevice& mesh_device, int dim) -> std::unique_ptr { + return std::make_unique(ShardTensorToMesh(mesh_device, dim)); }), py::kw_only(), py::arg("mesh_device"), py::arg("dim")) .def( - py::init([](size_t num_devices, int dim) -> std::unique_ptr { + py::init([](size_t num_devices, int dim) -> std::unique_ptr { return std::make_unique(ShardTensorToMesh(num_devices, dim)); }), py::kw_only(), @@ -660,17 +668,18 @@ void py_module(py::module& module) { py::init( [](MeshDevice& mesh_device, const MeshShape& mesh_shape, - const Shard2dConfig& config) -> std::unique_ptr { - return ttnn::distributed::shard_tensor_to_2d_mesh_mapper(mesh_device, mesh_shape, config); + const Shard2dConfig& config) -> std::unique_ptr { + return std::make_unique(ShardTensorTo2dMesh(mesh_device, mesh_shape, config)); }), py::kw_only(), py::arg("mesh_device"), py::arg("mesh_shape"), py::arg("config")) .def( - py::init([](const MeshShape& mesh_shape, const Shard2dConfig& config) -> std::unique_ptr { - return std::make_unique(ShardTensorTo2dMesh(mesh_shape, config)); - }), + py::init( + [](const MeshShape& mesh_shape, const Shard2dConfig& config) -> std::unique_ptr { + return std::make_unique(ShardTensorTo2dMesh(mesh_shape, config)); + }), py::kw_only(), py::arg("mesh_shape"), py::arg("config")) @@ -680,8 +689,8 @@ void py_module(py::module& module) { py::arg("tensor")) .def("config", &ShardTensorTo2dMesh::config); - auto py_mesh_to_tensor = - static_cast>>(module.attr("MeshToTensor")); + auto py_mesh_to_tensor = static_cast>>( + module.attr("MeshToTensor")); py_mesh_to_tensor .def(py::init([]() -> std::unique_ptr { return std::make_unique(); })) .def("compose", &MeshToTensor::compose); @@ -690,8 +699,8 @@ void py_module(py::module& module) { module.attr("ConcatMeshToTensor")); py_concat_mesh_to_tensor .def( - py::init([](int dim) -> std::unique_ptr { - return ttnn::distributed::concat_mesh_to_tensor_composer(dim); + py::init([](int dim) -> std::unique_ptr { + return std::make_unique(dim); }), 
py::kw_only(), py::arg("dim")) @@ -718,9 +727,15 @@ void py_module(py::module& module) { >>>>>>> expose classes to python ======= .def( - py::init([](MeshDevice& mesh_device, const Concat2dConfig& config) -> std::unique_ptr { - return ttnn::distributed::concat_2d_mesh_to_tensor_composer(mesh_device, config); - }), + py::init( + [](MeshDevice& mesh_device, const Concat2dConfig& config) -> std::unique_ptr { + TT_FATAL( + config.row_dim != config.col_dim, + "Dimensions in 'dims' must be different; got row_dim: {}, col_dim: {}", + config.row_dim, + config.col_dim); + return std::make_unique(mesh_device, config); + }), py::kw_only(), py::arg("mesh_device"), py::arg("config")) @@ -846,8 +861,11 @@ void py_module(py::module& module) { )doc"); module.def("get_device_tensors", &get_device_tensors, py::arg("tensor"), py::kw_only()); <<<<<<< HEAD +<<<<<<< HEAD <<<<<<< HEAD // TODO: Add rdocs +======= +>>>>>>> move class definitions from from distributed_tensor.cpp to.hpp so they can be exposed to the pybind.cpp; add dummy void methods in .cpp to satisfy linker; add new constructors and factory methods to fix type errors module.def( "replicate_tensor_to_mesh_mapper", [](MeshDevice& mesh_device) -> std::unique_ptr { @@ -872,6 +890,7 @@ void py_module(py::module& module) { py::arg("mesh_shape"), py::arg("config")); module.def( +<<<<<<< HEAD "shard_tensor_to_2d_mesh_mapper", [](MeshDevice& mesh_device, const std::tuple mesh_shape, @@ -896,6 +915,8 @@ void py_module(py::module& module) { TensorToMesh: The created ShardTensor2dMesh mapper. )doc"); module.def( +======= +>>>>>>> move class definitions from from distributed_tensor.cpp to.hpp so they can be exposed to the pybind.cpp; add dummy void methods in .cpp to satisfy linker; add new constructors and factory methods to fix type errors "concat_mesh_to_tensor_composer", [](int dim) -> std::unique_ptr { return concat_mesh_to_tensor_composer(dim); }, py::arg("dim")); @@ -906,6 +927,7 @@ void py_module(py::module& module) { }, py::arg("mesh_device"), py::arg("config")); +<<<<<<< HEAD module.def( "concat_2d_mesh_to_tensor_composer", [](MeshDevice& mesh_device, @@ -960,6 +982,8 @@ void py_module(py::module& module) { //TODO: overload this method to enable selection of a subset of shards with a config or something before passing to aggregate >>>>>>> expose classes to python ======= +======= +>>>>>>> move class definitions from from distributed_tensor.cpp to.hpp so they can be exposed to the pybind.cpp; add dummy void methods in .cpp to satisfy linker; add new constructors and factory methods to fix type errors // TODO: overload this method to enable selection of a subset of shards with a config or something before passing to // aggregate >>>>>>> one type error left diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp index a381564fe66..ee8dc7a4863 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp @@ -13,6 +13,12 @@ #include "ttnn/tensor/xtensor/partition.hpp" #include #include +#include "ttnn/distributed/api.hpp" +#include "ttnn/distributed/distributed_tensor_config.hpp" +#include "ttnn/distributed/types.hpp" +#include "ttnn/tensor/xtensor/partition.hpp" +#include +#include namespace ttnn::distributed { @@ -45,6 +51,8 @@ class ReplicateTensorToMesh : public TensorToMesh { public: ReplicateTensorToMesh(size_t num_devices) : num_devices_(num_devices) {} + ReplicateTensorToMesh(MeshDevice& mesh_device) : 
num_devices_(mesh_device.num_devices()) {} + std::vector map(const Tensor& tensor) const override { std::vector tensors; tensors.reserve(num_devices_); @@ -64,6 +72,8 @@ class ShardTensorToMesh : public TensorToMesh { public: ShardTensorToMesh(size_t num_devices, int dim) : num_devices_(num_devices), shard_dim_(dim) {} + ShardTensorToMesh(MeshDevice& mesh_device, int dim) : num_devices_(mesh_device.num_devices()), shard_dim_(dim) {} + std::vector map(const Tensor& tensor) const override { return experimental::xtensor::chunk(tensor, num_devices_, shard_dim_); } @@ -79,6 +89,17 @@ class ShardTensorToMesh : public TensorToMesh { class ShardTensorTo2dMesh : public TensorToMesh { public: + ShardTensorTo2dMesh(MeshDevice& mesh_device, const MeshShape& mesh_shape, const Shard2dConfig& config) : + mesh_shape_(mesh_shape), config_(config) { + TT_FATAL( + config.row_dim.has_value() || config.col_dim.has_value(), + "Sharding a tensor to 2D mesh requires at least one dimension to shard"); + TT_FATAL( + mesh_shape.num_rows <= mesh_device.shape().num_rows && // + mesh_shape.num_cols <= mesh_device.shape().num_cols, + "Device mesh shape does not match the provided mesh shape."); + } + ShardTensorTo2dMesh(const MeshShape& mesh_shape, const Shard2dConfig& config) : mesh_shape_(mesh_shape), config_(config) {} @@ -188,6 +209,7 @@ std::unique_ptr concat_mesh_to_tensor_composer(int dim); // Creates a composer that concatenates a tensor across two dimensions. + std::unique_ptr concat_2d_mesh_to_tensor_composer(MeshDevice& mesh_device, const Concat2dConfig& config); // Distributes a host tensor onto multi-device configuration according to the `mapper`. From 935d2e5ca86ae1f75c9a9fdfa3ed37e71fdba802 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Mon, 10 Feb 2025 22:54:24 +0000 Subject: [PATCH 25/76] fix naming errors, add tests, add imports - TODO, fix weird aliasing error with meshdevice vs ttnn.multidevice.meshdevice --- .../distributed/test_distributed_tensor.py | 156 +++---- .../ttnn/distributed/distributed_pybind.cpp | 34 +- .../ttnn/distributed/distributed_tensor.cpp | 8 + .../ttnn/distributed/distributed_tensor.hpp | 12 +- ttnn/ttnn/__init__.py | 20 +- ttnn/ttnn/distributed/distributed.py | 420 +++++++++--------- 6 files changed, 310 insertions(+), 340 deletions(-) diff --git a/tests/ttnn/distributed/test_distributed_tensor.py b/tests/ttnn/distributed/test_distributed_tensor.py index b31f411fa27..7248bf0cf63 100644 --- a/tests/ttnn/distributed/test_distributed_tensor.py +++ b/tests/ttnn/distributed/test_distributed_tensor.py @@ -1,41 +1,25 @@ -# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. 
# SPDX-License-Identifier: Apache-2.0 import torch +import typing import pytest import ttnn +import tempfile from loguru import logger from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc -from models.utility_functions import nearest_32 - -@pytest.mark.parametrize( - "mesh_device", - [ - 32, - ], - indirect=True, +from ttnn import ( + ShardTensorToMesh, + ShardTensor2dMesh, + ReplicateTensorToMesh, + ConcatMeshToTensor, + ConcatMesh2dToTensor, + MeshToTensor, + TensorToMesh, ) -@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) -def test_direct_replicate_to_tensor_mesh(mesh_device, dtype): - torch.manual_seed(1234) - - torch_tensor = torch.randn(1, 1, 32, 256) - to_repl = ttnn.from_torch( - torch_tensor, - dtype=dtype, - layout=ttnn.TILE_LAYOUT, - device=mesh_device, - ) - - mapper = ttnn.CppReplicateTensorToMesh(mesh_device) - replicated_tensors = ttnn.from_torch(to_repl, mapper, mesh_device) - out_tensors = ttnn.get_device_tensors(replicated_tensors) - - out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensors[0]), torch_tensor, pcc=0.99) - logger.info(f"PCC value: {out_pcc}") - assert out_pass +from models.utility_functions import nearest_32 @pytest.mark.parametrize( @@ -49,92 +33,66 @@ def test_direct_replicate_to_tensor_mesh(mesh_device, dtype): def test_replicate_to_tensor_mesh(mesh_device, dtype): torch.manual_seed(1234) - torch_tensor = torch.randn(1, 1, 32, 256) - to_repl = ttnn.from_torch( + torch_tensor = torch.randn(1, 1, 32, 8192) + replicated = ttnn.from_torch( torch_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=mesh_device, + mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), ) - mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) - replicated_tensors = ttnn.distribute_tensor(to_repl, mapper, mesh_device) - out_tensors = ttnn.get_device_tensors(replicated_tensors) + out_tensors = ttnn.get_device_tensors(mesh_device) + + test = ttnn.from_torch(torch_tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device) - out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensors[0]), torch_tensor, pcc=0.99) + out_pass, out_pcc = comp_pcc(out_tensors[0], test, pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass -@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) def test_shard_to_tensor_mesh(mesh_device, dtype): torch.manual_seed(1234) - torch_tensor = torch.randn(1, 1, 32, 256) - to_shard = ttnn.from_torch( + torch_tensor = torch.randn(1, 1, 8192, 32768) + tensor_shards = ttnn.from_torch( torch_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=mesh_device, + mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=3), ) - mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3) + test = ttnn.from_torch(torch_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=mesh_device) - shards = ttnn.get_device_tensors(ttnn.distribute_tensor(to_shard, mapper, mesh_device)) + out_tensors = ttnn.get_device_tensors(tensor_shards) - out_tensor = ttnn.aggregate_as_tensor(shards) + out_tensor = ttnn.aggregate_as_tensor(out_tensors) - out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensor), torch_tensor, pcc=0.99) + out_pass, out_pcc = comp_pcc(out_tensor, test, pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass -@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) def test_concat_to_tensor(mesh_device, dtype): torch.manual_seed(1234) - torch_tensor = 
torch.randn(1, 1, 32, 256) - to_shard = ttnn.from_torch( + torch_tensor = torch.randn(1, 1, 8192, 32768) + sharded = ttnn.from_torch( torch_tensor, - dtype=dtype, + dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, + mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=3), ) - mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3) - - composer = ttnn.concat_mesh_to_tensor_composer(dim=3) + test = ttnn.from_torch(torch_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=mesh_device) - out_tensor = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device), composer) - - out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensor), torch_tensor, pcc=0.99) - logger.info(f"PCC value: {out_pcc}") - assert out_pass - - -@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) -def test_concat_slice_to_tensor(mesh_device, dtype): - torch.manual_seed(1234) - - torch_tensor = torch.randn(1, 1, 32, 256) - to_shard = ttnn.from_torch( - torch_tensor, - dtype=dtype, - layout=ttnn.TILE_LAYOUT, - device=mesh_device, + out_tensor = ttnn.to_torch( + torch_tensor, dtype=ttnn.bfloat16, mesh_composer=ttnn.ConcatMeshToTensor(dim=3), device=mesh_device ) - mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3) - - composer = ttnn.concat_mesh_to_tensor_composer(dim=3) - - sharded_tensor = ttnn.distribute_tensor(to_shard, mapper, mesh_device) - - shards = ttnn.get_device_tensors(sharded_tensor) - - out_tensor = ttnn.aggregate_tensor(shards, composer) - - out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensor), torch_tensor, pcc=0.99) + out_pass, out_pcc = comp_pcc(out_tensor, test, pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass @@ -144,9 +102,8 @@ def test_concat_slice_to_tensor(mesh_device, dtype): ) @pytest.mark.parametrize( "M, K, N", - [pytest.param(32, 64, 128), pytest.param(32, 128, 64)], + [pytest.param(32, 8192, 28 * 1024), pytest.param(32, 28 * 1024, 8192)], ) -@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): torch.manual_seed(1234) @@ -154,7 +111,7 @@ def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): core_grid = ttnn.CoreGrid(y=1, x=8) # If K < N it's FF1-like test case, else FF2-like test case - shard_dim = (0, 3) if K < N else (3, 0) + shard_dim = (None, 3) if K < N else (3, None) # None means to replicate along this dim K = K // mesh_shape[1] if K < N else K // mesh_shape[0] N = N // mesh_shape[0] if K < N else N // mesh_shape[1] @@ -167,33 +124,25 @@ def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): use_height_and_width_as_shard_shape=True, ) - to_shard = ttnn.from_torch( + tensor_shards = ttnn.from_torch( torch_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT, memory_config=sharded_mem_config if M == 32 else ttnn.DRAM_MEMORY_CONFIG, device=mesh_device, + mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim), ) - mapper = ttnn.shard_tensor_to_2d_mesh_mapper(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) - - shards = ttnn.get_device_tensors(ttnn.distribute_tensor(to_shard, mapper, mesh_device)) + out_tensors = ttnn.get_device_tensors(tensor_shards) - ttnn.aggregate_as_tensor(shards) + for tensor in out_tensors: + print(tensor) - out_pass, out_pcc = comp_pcc(ttnn.to_torch(shards), torch_tensor, pcc=0.99) - logger.info(f"PCC value: {out_pcc}") - assert out_pass + # out_pass, out_pcc 
= comp_pcc(tensor_shards, out, pcc=0.99) + # logger.info(f"PCC value: {out_pcc}") + # assert out_pass -@pytest.mark.parametrize( - "mesh_shape, mesh_device", [pytest.param((8, 4), (8, 4), id="8x4_grid")], indirect=["mesh_device"] -) -@pytest.mark.parametrize( - "M, K, N", - [pytest.param(32, 128, 64), pytest.param(32, 128, 64)], -) -@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) def test_concat2d_to_tensor(M, K, N, dtype, mesh_shape, mesh_device): torch.manual_seed(1234) @@ -201,8 +150,8 @@ def test_concat2d_to_tensor(M, K, N, dtype, mesh_shape, mesh_device): core_grid = ttnn.CoreGrid(y=1, x=8) # If K < N it's FF1-like test case, else FF2-like test case - shard_dim = (0, 3) if K < N else (3, 0) - concat_dim = (3, 1) if K < N else (1, 3) + shard_dim = (None, 3) if K < N else (3, None) # None means to replicate along this dim + concat_dim = (3, 1) if K < N else (1, 3) # dim 1 for reduce, dim 3 for concatenating fractures K = K // mesh_shape[1] if K < N else K // mesh_shape[0] N = N // mesh_shape[0] if K < N else N // mesh_shape[1] @@ -215,20 +164,19 @@ def test_concat2d_to_tensor(M, K, N, dtype, mesh_shape, mesh_device): use_height_and_width_as_shard_shape=True, ) - to_shard = ttnn.from_torch( + tensor_shards = ttnn.from_torch( torch_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT, memory_config=sharded_mem_config if M == 32 else ttnn.DRAM_MEMORY_CONFIG, device=mesh_device, + mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim), ) - mapper = ttnn.shard_tensor_to_2d_mesh_mapper(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) - - composer = ttnn.concat_2d_mesh_to_tensor_composer(mesh_device, dims=concat_dim, mesh_shape=mesh_shape) - - out_tensor = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device), composer) + out = ttnn.to_torch( + tensor_shards, mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, dims=concat_dim, mesh_shape=mesh_shape) + ) - out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensor), torch_tensor, pcc=0.99) + out_pass, out_pcc = comp_pcc(out, torch_tensor, pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 49198ddc3af..3a79c671813 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -123,11 +123,16 @@ void py_module_types(py::module& module) { py::class_>( module, "ReplicateTensorToMesh"); py::class_>(module, "ShardTensorToMesh"); - py::class_>(module, "ShardTensorTo2dMesh"); + py::class_>(module, "ShardTensor2dMesh"); py::class_>(module, "ConcatMeshToTensor"); +<<<<<<< HEAD py::class_>( module, "Concat2dMeshToTensor"); >>>>>>> one type error left +======= + py::class_>( + module, "ConcatMesh2dToTensor"); +>>>>>>> fix naming errors, add tests, add imports - TODO, fix weird aliasing error with meshdevice vs ttnn.multidevice.meshdevice py::class_>(module, "MeshDevice"); py::class_(module, "MeshSubDeviceManagerId"); @@ -660,16 +665,15 @@ void py_module(py::module& module) { py::arg("tensor")) .def("config", &ShardTensorToMesh::config); - auto py_shard_tensor_to_2d_mesh = - static_cast>>( - module.attr("ShardTensorTo2dMesh")); + auto py_shard_tensor_to_2d_mesh = static_cast>>( + module.attr("ShardTensor2dMesh")); py_shard_tensor_to_2d_mesh .def( py::init( [](MeshDevice& mesh_device, const MeshShape& mesh_shape, - const Shard2dConfig& config) -> std::unique_ptr { - 
return std::make_unique(ShardTensorTo2dMesh(mesh_device, mesh_shape, config)); + const Shard2dConfig& config) -> std::unique_ptr { + return std::make_unique(ShardTensor2dMesh(mesh_device, mesh_shape, config)); }), py::kw_only(), py::arg("mesh_device"), @@ -677,17 +681,17 @@ void py_module(py::module& module) { py::arg("config")) .def( py::init( - [](const MeshShape& mesh_shape, const Shard2dConfig& config) -> std::unique_ptr { - return std::make_unique(ShardTensorTo2dMesh(mesh_shape, config)); + [](const MeshShape& mesh_shape, const Shard2dConfig& config) -> std::unique_ptr { + return std::make_unique(ShardTensor2dMesh(mesh_shape, config)); }), py::kw_only(), py::arg("mesh_shape"), py::arg("config")) .def( "map", - [](const ShardTensorTo2dMesh& self, const Tensor& tensor) { return self.map(tensor); }, + [](const ShardTensor2dMesh& self, const Tensor& tensor) { return self.map(tensor); }, py::arg("tensor")) - .def("config", &ShardTensorTo2dMesh::config); + .def("config", &ShardTensor2dMesh::config); auto py_mesh_to_tensor = static_cast>>( module.attr("MeshToTensor")); @@ -710,8 +714,8 @@ void py_module(py::module& module) { py::arg("tensors")); auto py_concat_2d_mesh_to_tensor = - static_cast>>( - module.attr("Concat2dMeshToTensor")); + static_cast>>( + module.attr("ConcatMesh2dToTensor")); py_concat_2d_mesh_to_tensor <<<<<<< HEAD .def(py::init<>(MeshDevice & mesh_device, const Concat2dConfig& config) { @@ -728,20 +732,20 @@ void py_module(py::module& module) { ======= .def( py::init( - [](MeshDevice& mesh_device, const Concat2dConfig& config) -> std::unique_ptr { + [](MeshDevice& mesh_device, const Concat2dConfig& config) -> std::unique_ptr { TT_FATAL( config.row_dim != config.col_dim, "Dimensions in 'dims' must be different; got row_dim: {}, col_dim: {}", config.row_dim, config.col_dim); - return std::make_unique(mesh_device, config); + return std::make_unique(mesh_device, config); }), py::kw_only(), py::arg("mesh_device"), py::arg("config")) .def( "compose", - [](Concat2dMeshToTensor self, const std::vector& tensors) -> Tensor { + [](ConcatMesh2dToTensor self, const std::vector& tensors) -> Tensor { return self.compose(tensors); }, py::arg("tensors")); diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp index 37556c7324d..dccae79d68f 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp @@ -162,6 +162,7 @@ std::unique_ptr shard_tensor_to_2d_mesh_mapper( mesh_shape[1] <= mesh_device.shape()[1], "Device mesh shape does not match the provided mesh shape."); <<<<<<< HEAD +<<<<<<< HEAD <<<<<<< HEAD return std::make_unique(mesh_shape[0], mesh_shape[1], config); ======= @@ -170,6 +171,9 @@ std::unique_ptr shard_tensor_to_2d_mesh_mapper( ======= return std::make_unique(mesh_shape, config); >>>>>>> add back distributed.py for now, clean up class overloads +======= + return std::make_unique(mesh_shape, config); +>>>>>>> fix naming errors, add tests, add imports - TODO, fix weird aliasing error with meshdevice vs ttnn.multidevice.meshdevice } std::unique_ptr concat_mesh_to_tensor_composer(int dim) { @@ -183,6 +187,7 @@ std::unique_ptr concat_2d_mesh_to_tensor_composer(MeshDevice& mesh config.row_dim, config.col_dim); <<<<<<< HEAD +<<<<<<< HEAD <<<<<<< HEAD TT_FATAL(mesh_device.shape().dims() == 2, "Mesh device is not configured as a 2D mesh: {}", mesh_device.shape()); return std::make_unique(mesh_device.shape()[0], mesh_device.shape()[1], config); @@ -192,6 +197,9 @@ 
std::unique_ptr concat_2d_mesh_to_tensor_composer(MeshDevice& mesh ======= return std::make_unique(mesh_device, config); >>>>>>> add back distributed.py for now, clean up class overloads +======= + return std::make_unique(mesh_device, config); +>>>>>>> fix naming errors, add tests, add imports - TODO, fix weird aliasing error with meshdevice vs ttnn.multidevice.meshdevice } Tensor distribute_tensor( diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp index ee8dc7a4863..7ae8d24cb5b 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp @@ -87,9 +87,9 @@ class ShardTensorToMesh : public TensorToMesh { int shard_dim_ = -1; }; -class ShardTensorTo2dMesh : public TensorToMesh { +class ShardTensor2dMesh : public TensorToMesh { public: - ShardTensorTo2dMesh(MeshDevice& mesh_device, const MeshShape& mesh_shape, const Shard2dConfig& config) : + ShardTensor2dMesh(MeshDevice& mesh_device, const MeshShape& mesh_shape, const Shard2dConfig& config) : mesh_shape_(mesh_shape), config_(config) { TT_FATAL( config.row_dim.has_value() || config.col_dim.has_value(), @@ -100,7 +100,7 @@ class ShardTensorTo2dMesh : public TensorToMesh { "Device mesh shape does not match the provided mesh shape."); } - ShardTensorTo2dMesh(const MeshShape& mesh_shape, const Shard2dConfig& config) : + ShardTensor2dMesh(const MeshShape& mesh_shape, const Shard2dConfig& config) : mesh_shape_(mesh_shape), config_(config) {} std::vector map(const Tensor& tensor) const override { @@ -137,7 +137,7 @@ class ShardTensorTo2dMesh : public TensorToMesh { TT_FATAL( static_cast(tensor_shards.size()) == rows * cols, - "ShardTensorTo2dMesh: Sharding failed. Number of shards should match the product of the mesh " + "ShardTensor2dMesh: Sharding failed. Number of shards should match the product of the mesh " "dimensions. 
Size: {}, rows: {}, cols: {}", tensor_shards.size(), rows, @@ -167,9 +167,9 @@ class ConcatMeshToTensor : public MeshToTensor { int concat_dim_ = -1; }; -class Concat2dMeshToTensor : public MeshToTensor { +class ConcatMesh2dToTensor : public MeshToTensor { public: - Concat2dMeshToTensor(MeshDevice& mesh_device, const Concat2dConfig& config) : + ConcatMesh2dToTensor(MeshDevice& mesh_device, const Concat2dConfig& config) : mesh_shape_(mesh_device.shape()), config_(config) {} Tensor compose(const std::vector& tensors) const override { diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index dce198ccc88..01469e50cc3 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -96,19 +96,13 @@ def manage_config(name, value): from ttnn._ttnn.multi_device import ( MeshDevice, - CppMeshToTensor, - CppTensorToMesh, - CppReplicateTensorToMesh, - CppShardTensorToMesh, - CppShardTensorTo2dMesh, - CppConcatMeshToTensor, - CppConcat2dMeshToTensor, - ReplicateTensor, - ShardTensor, - ShardTensor2d, - ShardMesh, - AllGatherTensor, - DistributedTensorConfig, + MeshToTensor, + TensorToMesh, + ReplicateTensorToMesh, + ShardTensorToMesh, + ShardTensor2dMesh, + ConcatMeshToTensor, + ConcatMesh2dToTensor, get_device_tensor, get_device_tensors, get_shard2d_config, diff --git a/ttnn/ttnn/distributed/distributed.py b/ttnn/ttnn/distributed/distributed.py index fa057bd0051..d59377181df 100644 --- a/ttnn/ttnn/distributed/distributed.py +++ b/ttnn/ttnn/distributed/distributed.py @@ -209,224 +209,240 @@ def synchronize_devices( ttnn._ttnn.device.synchronize_device(devices.get_device(device), queue_id, sub_device_ids) -# TODO: All of the TensorTo and MeshTo classes will be slowly cut out over the next few days -class TensorToMesh: - """ - Defines the mapping of a torch.Tensor to a device mesh: e.g. Shard/Replicate. - You can also "Bring your own TensorToMesh" based on your custom mapping. - """ +# class TensorToMesh: +# """ +# Defines the mapping of a torch.Tensor to a device mesh: e.g. Shard/Replicate. +# You can also "Bring your own TensorToMesh" based on your custom mapping. +# """ - def __init__(self, mesh_device): - self.mesh_device = mesh_device +# def __init__(self, mesh_device): +# self.mesh_device = mesh_device - def map(self, tensor: "torch.Tensor"): - raise NotImplementedError("Subclasses must implement this method") +# def map(self, tensor: "torch.Tensor"): +# raise NotImplementedError("Subclasses must implement this method") - def config(self): - raise NotImplementedError("Subclasses must implement this method") +# def config(self): +# raise NotImplementedError("Subclasses must implement this method") -class MeshToTensor: - """ - Defines the inverse operation of TensorToMesh. Given a set of per-device - ttnn.Tensor objects (aggregated into a single ttnn.Tensor), this class defines - the mapping back to one or many torch.Tensor objects. - You can also "Bring your own MeshToTensor" based on your custom mapping. - """ +# class MeshToTensor: +# """ +# Defines the inverse operation of TensorToMesh. Given a set of per-device +# ttnn.Tensor objects (aggregated into a single ttnn.Tensor), this class defines +# the mapping back to one or many torch.Tensor objects. - def compose(self, tensor: ttnn.Tensor): - raise NotImplementedError("Subclasses must implement this method") +# You can also "Bring your own MeshToTensor" based on your custom mapping. 
+# """ +# def compose(self, tensor: ttnn.Tensor): +# raise NotImplementedError("Subclasses must implement this method") -class ShardTensorToMesh(TensorToMesh): - def __init__(self, mesh_device, dim): - super().__init__(mesh_device) - self.shard_dim = dim - def map(self, tensor: "torch.Tensor") -> Dict[int, ttnn.Tensor]: - import torch +# class ShardTensorToMesh(TensorToMesh): +# def __init__(self, mesh_device, dim): +# super().__init__(mesh_device) +# self.shard_dim = dim - sliced_tensors = torch.chunk(tensor, self.mesh_device.get_num_devices(), dim=self.shard_dim) - return list(sliced_tensors) +# def map(self, tensor: "torch.Tensor") -> Dict[int, ttnn.Tensor]: +# import torch - def config(self): - return { - "strategy": "shard", - "shard_dim": f"{self.shard_dim}", - } +# sliced_tensors = torch.chunk(tensor, self.mesh_device.get_num_devices(), dim=self.shard_dim) +# return list(sliced_tensors) +# def config(self): +# return { +# "strategy": "shard", +# "shard_dim": f"{self.shard_dim}", +# } -class ShardTensor2dMesh(TensorToMesh): - """ - Shard a tensor across a 2D mesh of devices. - This class implements a strategy for distributing a tensor across a 2D grid of devices, - allowing for efficient parallel processing in distributed computing environments. - """ - def __init__(self, mesh_device: MeshDevice, mesh_shape: Tuple[int, int], dims: Tuple[Optional[int], Optional[int]]): - """ - Initialize the ShardTensor2dMesh. - Args: - mesh_device: The target device mesh for distributing the tensor. - mesh_shape: The shape of the 2D mesh as (rows, cols). - dims: The dimensions to shard along, specified as (row_dim, col_dim). - The `dims` tuple determines how the tensor is sharded across the 2D mesh: - - row_dim: The dimension to shard across mesh rows (or None for replication). - - col_dim: The dimension to shard across mesh columns (or None for replication). - Examples: - 1. dims=(2, 3) for a tensor of shape (A, B, C, D): - - Shard along dimension 2 (C) across mesh rows - - Shard along dimension 3 (D) across mesh columns - 2. dims=(None, 3): - - Replicate across mesh rows - - Shard along dimension 3 (D) across mesh columns - 3. dims=(None, None): - - Fully replicate the tensor across all devices - """ - super().__init__(mesh_device) - self.mesh_shape: Tuple[int, int] = mesh_shape - self.dims: Tuple[Optional[int], Optional[int]] = dims - - mesh_device_rows, mesh_device_cols = self.mesh_device.shape - if mesh_shape[0] > mesh_device_rows or mesh_shape[1] > mesh_device_cols: - raise ValueError("ShardTensor2dMesh: Device mesh shape does not match the provided mesh shape.") - - def map(self, tensor: "torch.Tensor") -> List["torch.Tensor"]: - """ - Map the input tensor to a list of sharded tensors. - Args: - tensor: The input tensor to be sharded. - Returns: - A list of sharded tensors, one for each device in the mesh. - Raises: - ValueError: If the number of sharding dimensions is not 2. 
- """ - import torch - - if len(self.dims) != 2: - raise ValueError("ShardTensor2dMesh only supports 2D shard dimensions") - - rows, cols = self.mesh_shape - row_dim, col_dim = self.dims - - # Shard along rows - row_tensors = ( - [tensor.clone() for _ in range(rows)] if row_dim is None else torch.chunk(tensor, rows, dim=row_dim) - ) - - # Shard along columns - if col_dim is None: - return [t.clone() for t in row_tensors for _ in range(cols)] - tensor_shards = [tt for t in row_tensors for tt in torch.chunk(t, cols, dim=col_dim)] - - if len(tensor_shards) != rows * cols: - raise ValueError( - f"ShardTensor2dMesh: Sharding failed. Number of shards should match the product of the mesh dimensions. Got {len(tensor_shards)} shards but expected {rows * cols} ({rows} rows * {cols} cols)." - ) - - return tensor_shards - - def config(self) -> Dict[str, str]: - """ - Provide the configuration of the sharding strategy. - Returns: - A dictionary containing the sharding strategy and dimensions. - """ - return { - "strategy": "shard_2d", - "mesh_shape_y": str(self.mesh_shape[0]), - "mesh_shape_x": str(self.mesh_shape[1]), - } - - -class ConcatMesh2dToTensor(MeshToTensor): - """ - Concatenate tensors from a 2D mesh back into a single tensor. - This class implements the inverse operation of ShardTensor2dMesh, combining - sharded tensors from a 2D device mesh back into a single tensor. - """ +# class ShardTensor2dMesh(TensorToMesh): +# """ +# Shard a tensor across a 2D mesh of devices. + +# This class implements a strategy for distributing a tensor across a 2D grid of devices, +# allowing for efficient parallel processing in distributed computing environments. +# """ + +# def __init__(self, mesh_device: MeshDevice, mesh_shape: Tuple[int, int], dims: Tuple[Optional[int], Optional[int]]): +# """ +# Initialize the ShardTensor2dMesh. + +# Args: +# mesh_device: The target device mesh for distributing the tensor. +# mesh_shape: The shape of the 2D mesh as (rows, cols). +# dims: The dimensions to shard along, specified as (row_dim, col_dim). + +# The `dims` tuple determines how the tensor is sharded across the 2D mesh: +# - row_dim: The dimension to shard across mesh rows (or None for replication). +# - col_dim: The dimension to shard across mesh columns (or None for replication). + +# Examples: +# 1. dims=(2, 3) for a tensor of shape (A, B, C, D): +# - Shard along dimension 2 (C) across mesh rows +# - Shard along dimension 3 (D) across mesh columns + +# 2. dims=(None, 3): +# - Replicate across mesh rows +# - Shard along dimension 3 (D) across mesh columns + +# 3. dims=(None, None): +# - Fully replicate the tensor across all devices +# """ +# super().__init__(mesh_device) +# self.mesh_shape: Tuple[int, int] = mesh_shape +# self.dims: Tuple[Optional[int], Optional[int]] = dims + +# mesh_device_rows, mesh_device_cols = self.mesh_device.shape +# if mesh_shape[0] > mesh_device_rows or mesh_shape[1] > mesh_device_cols: +# raise ValueError("ShardTensor2dMesh: Device mesh shape does not match the provided mesh shape.") + +# def map(self, tensor: "torch.Tensor") -> List["torch.Tensor"]: +# """ +# Map the input tensor to a list of sharded tensors. + +# Args: +# tensor: The input tensor to be sharded. + +# Returns: +# A list of sharded tensors, one for each device in the mesh. + +# Raises: +# ValueError: If the number of sharding dimensions is not 2. 
+# """ +# import torch + +# if len(self.dims) != 2: +# raise ValueError("ShardTensor2dMesh only supports 2D shard dimensions") + +# rows, cols = self.mesh_shape +# row_dim, col_dim = self.dims + +# # Shard along rows +# row_tensors = ( +# [tensor.clone() for _ in range(rows)] if row_dim is None else torch.chunk(tensor, rows, dim=row_dim) +# ) + +# # Shard along columns +# if col_dim is None: +# return [t.clone() for t in row_tensors for _ in range(cols)] +# tensor_shards = [tt for t in row_tensors for tt in torch.chunk(t, cols, dim=col_dim)] + +# if len(tensor_shards) != rows * cols: +# raise ValueError( +# f"ShardTensor2dMesh: Sharding failed. Number of shards should match the product of the mesh dimensions. Got {len(tensor_shards)} shards but expected {rows * cols} ({rows} rows * {cols} cols)." +# ) + +# return tensor_shards + +# def config(self) -> Dict[str, str]: +# """ +# Provide the configuration of the sharding strategy. + +# Returns: +# A dictionary containing the sharding strategy and dimensions. +# """ +# return { +# "strategy": "shard_2d", +# "mesh_shape_y": str(self.mesh_shape[0]), +# "mesh_shape_x": str(self.mesh_shape[1]), +# } + + +# class ConcatMesh2dToTensor(MeshToTensor): +# """ +# Concatenate tensors from a 2D mesh back into a single tensor. + +# This class implements the inverse operation of ShardTensor2dMesh, combining +# sharded tensors from a 2D device mesh back into a single tensor. +# """ + +# def __init__(self, mesh_device: MeshDevice, mesh_shape: Tuple[int, int], dims: Tuple[int, int]): +# """ +# Initialize the ConcatMesh2dToTensor. + +# Args: +# mesh_device: The source device mesh containing the sharded tensors. +# mesh_shape: The shape of the 2D mesh as (rows, cols). +# dims: A tuple of two integers specifying the dimensions along which to concatenate the tensors. +# The first element (row_dim) indicates the dimension for concatenating tensors from different rows. +# The second element (col_dim) indicates the dimension for concatenating tensors from different columns. +# Both dimensions must be specified and different from each other. +# These dimensions correspond to the tensor dimensions, not the mesh dimensions. +# For example, if the original tensor was 4D with shape (batch, channel, height, width), +# and it was sharded across height and width, dims might be (-2, -1) or (2, 3). + +# Raises: +# ValueError: If either dimension in 'dims' is None or if both dimensions are the same. +# """ +# self.mesh_device = mesh_device +# self.mesh_shape = mesh_shape +# self.dims = dims +# if self.dims[0] == self.dims[1]: +# raise ValueError("Both dimensions in 'dims' must be different") + +# def compose(self, tensor: ttnn.Tensor) -> "torch.Tensor": +# """ +# Compose the sharded tensors back into a single tensor. + +# Args: +# tensor: A ttnn.Tensor object containing the sharded tensors distributed across multiple devices. + +# Returns: +# A single torch.Tensor that combines all the sharded tensors from all devices. + +# This method first concatenates the shards along the column dimension within each row, +# then concatenates the resulting tensors along the row dimension to form the final tensor. 
+# """ +# import torch + +# device_shards = [ +# ttnn.to_torch(tt_input_tensor, mesh_composer=None) for tt_input_tensor in ttnn.get_device_tensors(tensor) +# ] + +# rows, cols = self.mesh_shape +# row_dim, col_dim = self.dims + +# # Reshape the list of shards into a 2D list representing the device mesh +# mesh_shape = [device_shards[i : i + cols] for i in range(0, len(device_shards), cols)] + +# # Concatenate along columns first (within each row) +# row_concatenated = [torch.cat(row, dim=col_dim) for row in mesh_shape] + +# # Then concatenate the resulting tensors along rows +# return torch.cat(row_concatenated, dim=row_dim) + + +# class ReplicateTensorToMesh(TensorToMesh): +# def __init__(self, mesh_device: MeshDevice): +# super().__init__(mesh_device) + +# def map(self, tensor: "torch.Tensor"): +# return [tensor for i in range(self.mesh_device.get_num_devices())] + +# def config(self): +# return { +# "strategy": "replicate", +# "replication_factor": str(self.mesh_device.get_num_devices()), +# } + + +# class ConcatMeshToTensor(MeshToTensor): +# def __init__(self, mesh_device: MeshDevice, dim: int): +# self.concat_dim = dim +# self.mesh_device = mesh_device + +# def compose(self, tensor: ttnn.Tensor) -> "torch.Tensor": +# import torch - def __init__(self, mesh_device: MeshDevice, mesh_shape: Tuple[int, int], dims: Tuple[int, int]): - """ - Initialize the ConcatMesh2dToTensor. - Args: - mesh_device: The source device mesh containing the sharded tensors. - mesh_shape: The shape of the 2D mesh as (rows, cols). - dims: A tuple of two integers specifying the dimensions along which to concatenate the tensors. - The first element (row_dim) indicates the dimension for concatenating tensors from different rows. - The second element (col_dim) indicates the dimension for concatenating tensors from different columns. - Both dimensions must be specified and different from each other. - These dimensions correspond to the tensor dimensions, not the mesh dimensions. - For example, if the original tensor was 4D with shape (batch, channel, height, width), - and it was sharded across height and width, dims might be (-2, -1) or (2, 3). - Raises: - ValueError: If either dimension in 'dims' is None or if both dimensions are the same. - """ - self.mesh_device = mesh_device - self.mesh_shape = mesh_shape - self.dims = dims - if self.dims[0] == self.dims[1]: - raise ValueError("Both dimensions in 'dims' must be different") - - def compose(self, tensor: ttnn.Tensor) -> "torch.Tensor": - """ - Compose the sharded tensors back into a single tensor. - Args: - tensor: A ttnn.Tensor object containing the sharded tensors distributed across multiple devices. - Returns: - A single torch.Tensor that combines all the sharded tensors from all devices. - This method first concatenates the shards along the column dimension within each row, - then concatenates the resulting tensors along the row dimension to form the final tensor. 
- """ - import torch - - device_shards = [ - ttnn.to_torch(tt_input_tensor, mesh_composer=None) for tt_input_tensor in ttnn.get_device_tensors(tensor) - ] - - rows, cols = self.mesh_shape - row_dim, col_dim = self.dims - - # Reshape the list of shards into a 2D list representing the device mesh - mesh_shape = [device_shards[i : i + cols] for i in range(0, len(device_shards), cols)] - - # Concatenate along columns first (within each row) - row_concatenated = [torch.cat(row, dim=col_dim) for row in mesh_shape] - - # Then concatenate the resulting tensors along rows - return torch.cat(row_concatenated, dim=row_dim) - - -class ReplicateTensorToMesh(TensorToMesh): - def __init__(self, mesh_device: MeshDevice): - super().__init__(mesh_device) - - def map(self, tensor: "torch.Tensor"): - return [tensor for i in range(self.mesh_device.get_num_devices())] - - def config(self): - return { - "strategy": "replicate", - "replication_factor": str(self.mesh_device.get_num_devices()), - } - - -class ConcatMeshToTensor(MeshToTensor): - def __init__(self, mesh_device: MeshDevice, dim: int): - self.concat_dim = dim - self.mesh_device = mesh_device - - def compose(self, tensor: ttnn.Tensor) -> "torch.Tensor": - import torch - - device_shards_converted_to_torch = [ - ttnn.to_torch(tt_input_tensor, mesh_composer=None) for tt_input_tensor in ttnn.get_device_tensors(tensor) - ] - return torch.cat(device_shards_converted_to_torch, dim=self.concat_dim) +# device_shards_converted_to_torch = [ +# ttnn.to_torch(tt_input_tensor, mesh_composer=None) for tt_input_tensor in ttnn.get_device_tensors(tensor) +# ] +# return torch.cat(device_shards_converted_to_torch, dim=self.concat_dim) @contextlib.contextmanager -def distribute(default: Union[TensorToMesh, MeshToTensor]): +def distribute(default: Union[ttnn.TensorToMesh, ttnn.MeshToTensor]): """ Context manager to temporarily modify the behavior of ttnn.from_torch and ttnn.to_torch to use the specified mesh_mapper or mesh_composer for tensor distribution and composition to/from MeshDevice. 
@@ -449,9 +465,9 @@ def distribute(default: Union[TensorToMesh, MeshToTensor]): _original_from_torch = ttnn.from_torch try: - if isinstance(default, TensorToMesh): + if isinstance(default, ttnn.TensorToMesh): ttnn.from_torch = functools.partial(_original_from_torch, mesh_mapper=default) - elif isinstance(default, MeshToTensor): + elif isinstance(default, ttnn.MeshToTensor): ttnn.to_torch = functools.partial(_original_to_torch, mesh_composer=default) else: raise ValueError("Argument must be an instance of either TensorToMesh or MeshToTensor.") From d02a0fbde50cb5523aa38de9df5d1c5054420be2 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Fri, 14 Feb 2025 19:32:13 +0000 Subject: [PATCH 26/76] fix mesh device conflict, add aggregate/distribute and config pybinds, fix keyword error --- .../distributed/test_distributed_tensor.py | 46 ++++++++++++------- .../ttnn/distributed/distributed_pybind.cpp | 39 ++++++++++++---- ttnn/ttnn/__init__.py | 6 +++ ttnn/ttnn/distributed/__init__.py | 1 - 4 files changed, 67 insertions(+), 25 deletions(-) diff --git a/tests/ttnn/distributed/test_distributed_tensor.py b/tests/ttnn/distributed/test_distributed_tensor.py index 7248bf0cf63..c766e7a63fc 100644 --- a/tests/ttnn/distributed/test_distributed_tensor.py +++ b/tests/ttnn/distributed/test_distributed_tensor.py @@ -34,15 +34,14 @@ def test_replicate_to_tensor_mesh(mesh_device, dtype): torch.manual_seed(1234) torch_tensor = torch.randn(1, 1, 32, 8192) - replicated = ttnn.from_torch( + to_repl = ttnn.from_torch( torch_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), ) - out_tensors = ttnn.get_device_tensors(mesh_device) + out_tensors = ttnn.ReplicateTensorToMesh(mesh_device).map(to_repl) test = ttnn.from_torch(torch_tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device) @@ -51,6 +50,7 @@ def test_replicate_to_tensor_mesh(mesh_device, dtype): assert out_pass +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) def test_shard_to_tensor_mesh(mesh_device, dtype): torch.manual_seed(1234) @@ -60,12 +60,11 @@ def test_shard_to_tensor_mesh(mesh_device, dtype): dtype=dtype, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=3), ) test = ttnn.from_torch(torch_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=mesh_device) - out_tensors = ttnn.get_device_tensors(tensor_shards) + out_tensors = ttnn.ShardTensorToMesh(mesh_device, dim=3).map(tensor_shards) out_tensor = ttnn.aggregate_as_tensor(out_tensors) @@ -74,23 +73,23 @@ def test_shard_to_tensor_mesh(mesh_device, dtype): assert out_pass +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) def test_concat_to_tensor(mesh_device, dtype): torch.manual_seed(1234) torch_tensor = torch.randn(1, 1, 8192, 32768) - sharded = ttnn.from_torch( + to_shard = ttnn.from_torch( torch_tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=3), ) test = ttnn.from_torch(torch_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=mesh_device) - out_tensor = ttnn.to_torch( - torch_tensor, dtype=ttnn.bfloat16, mesh_composer=ttnn.ConcatMeshToTensor(dim=3), device=mesh_device - ) + sharded_tensors = ttnn.ShardTensorToMesh(mesh_device, dim=3).map(to_shard) + + out_tensor = ttnn.to_torch(ttnn.ConcatMeshToTensor(dim=3).compose(), dtype=ttnn.bfloat16, device=mesh_device) 
out_pass, out_pcc = comp_pcc(out_tensor, test, pcc=0.99) logger.info(f"PCC value: {out_pcc}") @@ -104,6 +103,7 @@ def test_concat_to_tensor(mesh_device, dtype): "M, K, N", [pytest.param(32, 8192, 28 * 1024), pytest.param(32, 28 * 1024, 8192)], ) +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): torch.manual_seed(1234) @@ -124,16 +124,17 @@ def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): use_height_and_width_as_shard_shape=True, ) - tensor_shards = ttnn.from_torch( + to_shard = ttnn.from_torch( torch_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT, memory_config=sharded_mem_config if M == 32 else ttnn.DRAM_MEMORY_CONFIG, device=mesh_device, - mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim), ) - out_tensors = ttnn.get_device_tensors(tensor_shards) + out_tensors = ttnn.get_device_tensors( + ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim).map(to_shard) + ) for tensor in out_tensors: print(tensor) @@ -143,6 +144,14 @@ def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): # assert out_pass +@pytest.mark.parametrize( + "mesh_shape, mesh_device", [pytest.param((8, 4), (8, 4), id="8x4_grid")], indirect=["mesh_device"] +) +@pytest.mark.parametrize( + "M, K, N", + [pytest.param(32, 8192, 28 * 1024), pytest.param(32, 28 * 1024, 8192)], +) +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) def test_concat2d_to_tensor(M, K, N, dtype, mesh_shape, mesh_device): torch.manual_seed(1234) @@ -164,17 +173,22 @@ def test_concat2d_to_tensor(M, K, N, dtype, mesh_shape, mesh_device): use_height_and_width_as_shard_shape=True, ) - tensor_shards = ttnn.from_torch( + to_shard = ttnn.from_torch( torch_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT, memory_config=sharded_mem_config if M == 32 else ttnn.DRAM_MEMORY_CONFIG, device=mesh_device, - mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim), + ) + + sharded_tensors = ttnn.get_device_tensors( + ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim).map(to_shard) ) out = ttnn.to_torch( - tensor_shards, mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, dims=concat_dim, mesh_shape=mesh_shape) + mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, dims=concat_dim, mesh_shape=mesh_shape).compose( + sharded_tensors + ), ) out_pass, out_pcc = comp_pcc(out, torch_tensor, pcc=0.99) diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 3a79c671813..492d67b65b0 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -134,6 +134,13 @@ void py_module_types(py::module& module) { module, "ConcatMesh2dToTensor"); >>>>>>> fix naming errors, add tests, add imports - TODO, fix weird aliasing error with meshdevice vs ttnn.multidevice.meshdevice + py::class_(module, "ReplicateTensor"); + py::class_(module, "ShardTensor"); + py::class_(module, "ShardTensor2D"); + py::class_(module, "ShardMesh"); + py::class_(module, "AllGatherTensor"); + py::class_(module, "DistributedTensorConfig"); + py::class_>(module, "MeshDevice"); py::class_(module, "MeshSubDeviceManagerId"); py::class_(module, "MeshShape", "Shape of a mesh device."); @@ -628,13 +635,11 @@ void py_module(py::module& module) { py::init([](MeshDevice& 
mesh_device) -> std::unique_ptr { return std::make_unique(ReplicateTensorToMesh(mesh_device.num_devices())); }), - py::kw_only(), py::arg("mesh_device")) .def( py::init([](size_t num_devices) -> std::unique_ptr { return std::make_unique(ReplicateTensorToMesh(num_devices)); }), - py::kw_only(), py::arg("num_devices")) .def( "map", @@ -649,14 +654,12 @@ void py_module(py::module& module) { py::init([](MeshDevice& mesh_device, int dim) -> std::unique_ptr { return std::make_unique(ShardTensorToMesh(mesh_device, dim)); }), - py::kw_only(), py::arg("mesh_device"), py::arg("dim")) .def( py::init([](size_t num_devices, int dim) -> std::unique_ptr { return std::make_unique(ShardTensorToMesh(num_devices, dim)); }), - py::kw_only(), py::arg("num_devices"), py::arg("dim")) .def( @@ -675,7 +678,6 @@ void py_module(py::module& module) { const Shard2dConfig& config) -> std::unique_ptr { return std::make_unique(ShardTensor2dMesh(mesh_device, mesh_shape, config)); }), - py::kw_only(), py::arg("mesh_device"), py::arg("mesh_shape"), py::arg("config")) @@ -684,7 +686,6 @@ void py_module(py::module& module) { [](const MeshShape& mesh_shape, const Shard2dConfig& config) -> std::unique_ptr { return std::make_unique(ShardTensor2dMesh(mesh_shape, config)); }), - py::kw_only(), py::arg("mesh_shape"), py::arg("config")) .def( @@ -706,7 +707,6 @@ void py_module(py::module& module) { py::init([](int dim) -> std::unique_ptr { return std::make_unique(dim); }), - py::kw_only(), py::arg("dim")) .def( "compose", @@ -740,7 +740,6 @@ void py_module(py::module& module) { config.col_dim); return std::make_unique(mesh_device, config); }), - py::kw_only(), py::arg("mesh_device"), py::arg("config")) .def( @@ -866,10 +865,14 @@ void py_module(py::module& module) { module.def("get_device_tensors", &get_device_tensors, py::arg("tensor"), py::kw_only()); <<<<<<< HEAD <<<<<<< HEAD +<<<<<<< HEAD <<<<<<< HEAD // TODO: Add rdocs ======= >>>>>>> move class definitions from from distributed_tensor.cpp to.hpp so they can be exposed to the pybind.cpp; add dummy void methods in .cpp to satisfy linker; add new constructors and factory methods to fix type errors +======= + // TODO: Add rdocs +>>>>>>> fix mesh device conflict, add aggregate/distribute and config pybinds, fix keyword error module.def( "replicate_tensor_to_mesh_mapper", [](MeshDevice& mesh_device) -> std::unique_ptr { @@ -931,6 +934,7 @@ void py_module(py::module& module) { }, py::arg("mesh_device"), py::arg("config")); +<<<<<<< HEAD <<<<<<< HEAD module.def( "concat_2d_mesh_to_tensor_composer", @@ -958,11 +962,14 @@ void py_module(py::module& module) { Returns: TensorToMesh: The created ConcatMesh2dToTensor composer. 
)doc"); +======= +>>>>>>> fix mesh device conflict, add aggregate/distribute and config pybinds, fix keyword error module.def( "distribute_tensor", [](const Tensor& tensor, const TensorToMesh& mapper, std::optional> mesh_device) -> Tensor { +<<<<<<< HEAD return distribute_tensor(from_device(tensor), mapper, mesh_device); }, py::arg("tensor"), @@ -973,11 +980,21 @@ void py_module(py::module& module) { [](const Tensor& tensor, const MeshToTensor& composer) -> Tensor { return aggregate_tensor(from_device(tensor), composer); }, +======= + return distribute_tensor(tensor, mapper, mesh_device); + }, + py::arg("tensor"), + py::arg("mapper")); + module.def( + "aggregate_tensor", + [](const Tensor& tensor, const MeshToTensor& composer) -> Tensor { return aggregate_tensor(tensor, composer); }, +>>>>>>> fix mesh device conflict, add aggregate/distribute and config pybinds, fix keyword error py::arg("tensor"), py::arg("composer")); module.def( "aggregate_tensor", [](const std::vector& tensors, const MeshToTensor& composer) -> Tensor { +<<<<<<< HEAD return aggregate_tensor(from_device(aggregate_as_tensor(tensors, AllGatherTensor{})), composer); }, py::arg("tensor"), @@ -991,6 +1008,12 @@ void py_module(py::module& module) { // TODO: overload this method to enable selection of a subset of shards with a config or something before passing to // aggregate >>>>>>> one type error left +======= + return aggregate_tensor(aggregate_as_tensor(tensors, AllGatherTensor{}), composer); + }, + py::arg("tensor"), + py::arg("composer")); +>>>>>>> fix mesh device conflict, add aggregate/distribute and config pybinds, fix keyword error module.def( "aggregate_as_tensor", [](const std::vector& tensors) -> Tensor { return aggregate_as_tensor(tensors, AllGatherTensor{}); }, diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index 01469e50cc3..312eeb61551 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -103,6 +103,12 @@ def manage_config(name, value): ShardTensor2dMesh, ConcatMeshToTensor, ConcatMesh2dToTensor, + ReplicateTensor, + ShardTensor, + ShardTensor2d, + ShardMesh, + AllGatherTensor, + DistributedTensorConfig, get_device_tensor, get_device_tensors, get_shard2d_config, diff --git a/ttnn/ttnn/distributed/__init__.py b/ttnn/ttnn/distributed/__init__.py index 4901c6ae8cb..c1fa3c25670 100644 --- a/ttnn/ttnn/distributed/__init__.py +++ b/ttnn/ttnn/distributed/__init__.py @@ -4,7 +4,6 @@ # TODO: All of the TensorTo and MeshTo classes will be slowly cut out over the next few days from .distributed import ( - MeshDevice, DispatchCoreType, TensorToMesh, ShardTensorToMesh, From 9afad34711c67a94c8808b6a0c3ddf88c450f1f8 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Fri, 14 Feb 2025 19:38:34 +0000 Subject: [PATCH 27/76] add aggregate/distribute imports to init --- ttnn/ttnn/__init__.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index 312eeb61551..cb71132c19d 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -115,11 +115,6 @@ def manage_config(name, value): get_concat2d_config, get_distributed_tensor_config, aggregate_as_tensor, - replicate_tensor_to_mesh_mapper, - shard_tensor_to_mesh_mapper, - shard_tensor_to_2d_mesh_mapper, - concat_mesh_to_tensor_composer, - concat_2d_mesh_to_tensor_composer, aggregate_tensor, distribute_tensor, get_t3k_physical_device_ids_ring, From 1071396fa10a6911510451764fd89b6c77a5eb31 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Fri, 14 Feb 2025 20:58:21 +0000 Subject: [PATCH 28/76] add configs to 
pybind --- .../ttnn/distributed/distributed_pybind.cpp | 24 ++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 492d67b65b0..58dd4f388f4 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -782,7 +782,11 @@ void py_module(py::module& module) { auto py_replicate_tensor_config = static_cast>(module.attr("ShardTensor")); py_replicate_tensor_config.def(py::init<>()) .def(py::init(), py::arg("replication_factor") = 1) +<<<<<<< HEAD .def_readwrite("shard_dimension", &ReplicateTensor::replication_factor) +======= + .def_readwrite("shard_dimension", &ShardTensor::shard_dimension) +>>>>>>> add configs to pybind .def("__eq__", [](const ReplicateTensor& a, const ReplicateTensor& b) { return a.replication_factor == b.replication_factor; }); @@ -791,6 +795,7 @@ void py_module(py::module& module) { py_shard_tensor_config.def(py::init(), py::arg("shard_dimension")) .def_readwrite("shard_dimension", &ShardTensor::shard_dimension) .def("__eq__", [](const ShardTensor& a, const ShardTensor& b) { return a == b; }); +<<<<<<< HEAD auto py_shard_mesh = static_cast>(module.attr("ShardMesh")); py_shard_mesh.def(py::init<>()).def_readwrite("y", &ShardMesh::y).def_readwrite("x", &ShardMesh::x); auto py_shard_tensor2d = static_cast>(module.attr("ShardTensor2d")); @@ -810,6 +815,19 @@ void py_module(py::module& module) { py_concat2d_config.def(py::init(), py::arg("row_dim"), py::arg("col_dim")) .def_readwrite("row_dim", &Concat2dConfig::row_dim) .def_readwrite("col_dim", &Concat2dConfig::col_dim); +======= + + auto py_shard_mesh = static_cast>(module.attr("ShardMesh")); + py_shard_mesh.def(py::init<>()).def_readwrite("y", &ShardMesh::y).def_readwrite("x", &ShardMesh::x); + + auto py_shard_tensor2d = static_cast>(module.attr("ShardTensor2D")); + py_shard_tensor2d.def(py::init(), py::arg("mesh")) + .def_readonly("shard_mesh", &ShardTensor2D::shard_mesh) + .def("__eq__", [](const ShardTensor2D& a, const ShardTensor2D& b) { return a == b; }); + + auto py_allgather_config = static_cast>(module.attr("AllGatherTensor")); + .def(py::init<>()).def("__eq__", [](const AllGatherTensor& a, const AllGatherTensor& b) { return a == b; }); +>>>>>>> add configs to pybind module.def( "get_distributed_tensor_config", @@ -823,6 +841,7 @@ void py_module(py::module& module) { "item": "field", } )doc"); +<<<<<<< HEAD module.def( "get_shard2d_config", &get_shard2d_config, @@ -845,6 +864,8 @@ void py_module(py::module& module) { "col_dim": "field", } )doc"); +======= +>>>>>>> add configs to pybind module.def( "get_device_tensor", py::overload_cast(&ttnn::distributed::get_device_tensor), @@ -984,7 +1005,8 @@ void py_module(py::module& module) { return distribute_tensor(tensor, mapper, mesh_device); }, py::arg("tensor"), - py::arg("mapper")); + py::arg("mapper"), + py::arg("mesh_device")); module.def( "aggregate_tensor", [](const Tensor& tensor, const MeshToTensor& composer) -> Tensor { return aggregate_tensor(tensor, composer); }, From 7f54f90541c19437e3d444f21058adb51f5c7842 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Fri, 14 Feb 2025 21:04:47 +0000 Subject: [PATCH 29/76] change test cases to use distribute/aggregate --- .../distributed/test_distributed_tensor.py | 85 ++++++++++++------- 1 file changed, 53 insertions(+), 32 deletions(-) diff --git a/tests/ttnn/distributed/test_distributed_tensor.py 
b/tests/ttnn/distributed/test_distributed_tensor.py index c766e7a63fc..903cbf0a23c 100644 --- a/tests/ttnn/distributed/test_distributed_tensor.py +++ b/tests/ttnn/distributed/test_distributed_tensor.py @@ -11,6 +11,8 @@ from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc from ttnn import ( + distribute_tensor, + aggregate_tensor, ShardTensorToMesh, ShardTensor2dMesh, ReplicateTensorToMesh, @@ -41,11 +43,11 @@ def test_replicate_to_tensor_mesh(mesh_device, dtype): device=mesh_device, ) - out_tensors = ttnn.ReplicateTensorToMesh(mesh_device).map(to_repl) + mapper = ttnn.ReplicateTensorToMesh(mesh_device).map(to_repl) + replicated_tensors = ttnn.distribute_tensor(to_repl, mapper, mesh_device) + out_tensors = ttnn.get_device_tensors(replicated_tensors) - test = ttnn.from_torch(torch_tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device) - - out_pass, out_pcc = comp_pcc(out_tensors[0], test, pcc=0.99) + out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensors[0]), torch_tensor, pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass @@ -55,20 +57,18 @@ def test_shard_to_tensor_mesh(mesh_device, dtype): torch.manual_seed(1234) torch_tensor = torch.randn(1, 1, 8192, 32768) - tensor_shards = ttnn.from_torch( + to_shard = ttnn.from_torch( torch_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=mesh_device, ) - test = ttnn.from_torch(torch_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=mesh_device) + mapper = ttnn.ShardTensorToMesh(mesh_device, dim=3).map(to_shard) - out_tensors = ttnn.ShardTensorToMesh(mesh_device, dim=3).map(tensor_shards) + out_tensor = ttnn.distribute_tensor(to_shard, mapper, mesh_device) - out_tensor = ttnn.aggregate_as_tensor(out_tensors) - - out_pass, out_pcc = comp_pcc(out_tensor, test, pcc=0.99) + out_pass, out_pcc = comp_pcc(out_tensor, torch_tensor, pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass @@ -80,18 +80,44 @@ def test_concat_to_tensor(mesh_device, dtype): torch_tensor = torch.randn(1, 1, 8192, 32768) to_shard = ttnn.from_torch( torch_tensor, - dtype=ttnn.bfloat16, + dtype=dtype, layout=ttnn.TILE_LAYOUT, device=mesh_device, ) - test = ttnn.from_torch(torch_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT, device=mesh_device) + mapper = ttnn.ShardTensorToMesh(mesh_device, dim=3).map(to_shard) - sharded_tensors = ttnn.ShardTensorToMesh(mesh_device, dim=3).map(to_shard) + composer = ttnn.ConcatMeshToTensor(dim=3) - out_tensor = ttnn.to_torch(ttnn.ConcatMeshToTensor(dim=3).compose(), dtype=ttnn.bfloat16, device=mesh_device) + out_tensor = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device), composer) - out_pass, out_pcc = comp_pcc(out_tensor, test, pcc=0.99) + out_pass, out_pcc = comp_pcc(out_tensor, torch_tensor, pcc=0.99) + logger.info(f"PCC value: {out_pcc}") + assert out_pass + + +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +def test_concat_slice_to_tensor(mesh_device, dtype): + torch.manual_seed(1234) + + torch_tensor = torch.randn(1, 1, 8192, 32768) + to_shard = ttnn.from_torch( + torch_tensor, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + device=mesh_device, + ) + + mapper = ttnn.ShardTensorToMesh(mesh_device, dim=3) + + composer = ttnn.ConcatMeshToTensor(dim=3) + + out_tensor = [] + out_tensor[0] = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device)[:-2], composer) + out_tensor[1] = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device)[:-1], composer) + 
out_tensor[2] = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device)[:0], composer) + + out_pass, out_pcc = comp_pcc(out_tensor, torch_tensor, pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass @@ -132,16 +158,15 @@ def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): device=mesh_device, ) - out_tensors = ttnn.get_device_tensors( - ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim).map(to_shard) - ) + mapper = ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim).map(to_shard) + + out_tensors = ttnn.get_device_tensors(ttnn.distribute_tensor(to_shard, mapper, mesh_device)) - for tensor in out_tensors: - print(tensor) + ttnn.aggregate_as_tensor(out_tensors, mesh_device) - # out_pass, out_pcc = comp_pcc(tensor_shards, out, pcc=0.99) - # logger.info(f"PCC value: {out_pcc}") - # assert out_pass + out_pass, out_pcc = comp_pcc(out_tensors, torch_tensor, pcc=0.99) + logger.info(f"PCC value: {out_pcc}") + assert out_pass @pytest.mark.parametrize( @@ -181,16 +206,12 @@ def test_concat2d_to_tensor(M, K, N, dtype, mesh_shape, mesh_device): device=mesh_device, ) - sharded_tensors = ttnn.get_device_tensors( - ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim).map(to_shard) - ) + mapper = ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) - out = ttnn.to_torch( - mesh_composer=ttnn.ConcatMesh2dToTensor(mesh_device, dims=concat_dim, mesh_shape=mesh_shape).compose( - sharded_tensors - ), - ) + composer = ttnn.ConcatMesh2dToTensor(mesh_device, dims=concat_dim, mesh_shape=mesh_shape) + + out_tensor = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device), composer) - out_pass, out_pcc = comp_pcc(out, torch_tensor, pcc=0.99) + out_pass, out_pcc = comp_pcc(out_tensor, torch_tensor, pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass From e9d21c58e4a85daac33b54d233da881deb94b753 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Fri, 14 Feb 2025 22:58:12 +0000 Subject: [PATCH 30/76] fix test mappers, convert to cpu_tensor --- .../distributed/test_distributed_tensor.py | 8 ++-- .../ttnn/distributed/distributed_pybind.cpp | 41 ++++++++++++++++++- 2 files changed, 43 insertions(+), 6 deletions(-) diff --git a/tests/ttnn/distributed/test_distributed_tensor.py b/tests/ttnn/distributed/test_distributed_tensor.py index 903cbf0a23c..5abb2c0d690 100644 --- a/tests/ttnn/distributed/test_distributed_tensor.py +++ b/tests/ttnn/distributed/test_distributed_tensor.py @@ -43,7 +43,7 @@ def test_replicate_to_tensor_mesh(mesh_device, dtype): device=mesh_device, ) - mapper = ttnn.ReplicateTensorToMesh(mesh_device).map(to_repl) + mapper = ttnn.ReplicateTensorToMesh(mesh_device) replicated_tensors = ttnn.distribute_tensor(to_repl, mapper, mesh_device) out_tensors = ttnn.get_device_tensors(replicated_tensors) @@ -64,7 +64,7 @@ def test_shard_to_tensor_mesh(mesh_device, dtype): device=mesh_device, ) - mapper = ttnn.ShardTensorToMesh(mesh_device, dim=3).map(to_shard) + mapper = ttnn.ShardTensorToMesh(mesh_device, dim=3) out_tensor = ttnn.distribute_tensor(to_shard, mapper, mesh_device) @@ -85,7 +85,7 @@ def test_concat_to_tensor(mesh_device, dtype): device=mesh_device, ) - mapper = ttnn.ShardTensorToMesh(mesh_device, dim=3).map(to_shard) + mapper = ttnn.ShardTensorToMesh(mesh_device, dim=3) composer = ttnn.ConcatMeshToTensor(dim=3) @@ -158,7 +158,7 @@ def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): device=mesh_device, ) - mapper = 
ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim).map(to_shard) + mapper = ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) out_tensors = ttnn.get_device_tensors(ttnn.distribute_tensor(to_shard, mapper, mesh_device)) diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 58dd4f388f4..800da6df25e 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -26,6 +26,7 @@ #include "tt-metalium/mesh_coord.hpp" #include "distributed_tensor.hpp" #include "tt-metalium/assert.hpp" +<<<<<<< HEAD ======= #include "distributed_tensor.hpp" <<<<<<< HEAD @@ -47,8 +48,11 @@ #include "distributed_tensor.cpp" ======= >>>>>>> move class definitions from from distributed_tensor.cpp to.hpp so they can be exposed to the pybind.cpp; add dummy void methods in .cpp to satisfy linker; add new constructors and factory methods to fix type errors +======= +>>>>>>> fix test mappers, convert to cpu_tensor #include "ttnn/distributed/api.hpp" #include "ttnn/distributed/distributed_tensor_config.hpp" +#include "ttnn/operations/core/core.hpp" #include "ttnn/tensor/tensor_utils.hpp" >>>>>>> one type error left #include "ttnn/tensor/tensor.hpp" @@ -84,6 +88,14 @@ struct ConcreteMeshToTensor : MeshToTensor { } }; +Tensor get_cpu_tensor(const Tensor& tensor) { + if (is_device_tensor(tensor)) { + Tensor cpu_tensor = tensor.cpu(); + TT_ASSERT(is_device_tensor(cpu_tensor)); + } + return tensor; +} + void py_module_types(py::module& module) { <<<<<<< HEAD <<<<<<< HEAD @@ -136,7 +148,7 @@ void py_module_types(py::module& module) { py::class_(module, "ReplicateTensor"); py::class_(module, "ShardTensor"); - py::class_(module, "ShardTensor2D"); + py::class_(module, "ShardTensor2d"); py::class_(module, "ShardMesh"); py::class_(module, "AllGatherTensor"); py::class_(module, "DistributedTensorConfig"); @@ -782,11 +794,15 @@ void py_module(py::module& module) { auto py_replicate_tensor_config = static_cast>(module.attr("ShardTensor")); py_replicate_tensor_config.def(py::init<>()) .def(py::init(), py::arg("replication_factor") = 1) +<<<<<<< HEAD <<<<<<< HEAD .def_readwrite("shard_dimension", &ReplicateTensor::replication_factor) ======= .def_readwrite("shard_dimension", &ShardTensor::shard_dimension) >>>>>>> add configs to pybind +======= + .def_readwrite("shard_dimension", &ReplicateTensor::replication_factor) +>>>>>>> fix test mappers, convert to cpu_tensor .def("__eq__", [](const ReplicateTensor& a, const ReplicateTensor& b) { return a.replication_factor == b.replication_factor; }); @@ -820,14 +836,21 @@ void py_module(py::module& module) { auto py_shard_mesh = static_cast>(module.attr("ShardMesh")); py_shard_mesh.def(py::init<>()).def_readwrite("y", &ShardMesh::y).def_readwrite("x", &ShardMesh::x); - auto py_shard_tensor2d = static_cast>(module.attr("ShardTensor2D")); + auto py_shard_tensor2d = static_cast>(module.attr("ShardTensor2d")); py_shard_tensor2d.def(py::init(), py::arg("mesh")) .def_readonly("shard_mesh", &ShardTensor2D::shard_mesh) .def("__eq__", [](const ShardTensor2D& a, const ShardTensor2D& b) { return a == b; }); +<<<<<<< HEAD auto py_allgather_config = static_cast>(module.attr("AllGatherTensor")); .def(py::init<>()).def("__eq__", [](const AllGatherTensor& a, const AllGatherTensor& b) { return a == b; }); >>>>>>> add configs to pybind +======= + auto py_allgather_config = + static_cast>(module.attr("AllGatherTensor")) + .def(py::init<>()) + 
.def("__eq__", [](const AllGatherTensor& a, const AllGatherTensor& b) { return a == b; }); +>>>>>>> fix test mappers, convert to cpu_tensor module.def( "get_distributed_tensor_config", @@ -990,6 +1013,7 @@ void py_module(py::module& module) { [](const Tensor& tensor, const TensorToMesh& mapper, std::optional> mesh_device) -> Tensor { +<<<<<<< HEAD <<<<<<< HEAD return distribute_tensor(from_device(tensor), mapper, mesh_device); }, @@ -1003,19 +1027,29 @@ void py_module(py::module& module) { }, ======= return distribute_tensor(tensor, mapper, mesh_device); +======= + return distribute_tensor(get_cpu_tensor(tensor), mapper, mesh_device); +>>>>>>> fix test mappers, convert to cpu_tensor }, py::arg("tensor"), py::arg("mapper"), py::arg("mesh_device")); module.def( "aggregate_tensor", +<<<<<<< HEAD [](const Tensor& tensor, const MeshToTensor& composer) -> Tensor { return aggregate_tensor(tensor, composer); }, >>>>>>> fix mesh device conflict, add aggregate/distribute and config pybinds, fix keyword error +======= + [](const Tensor& tensor, const MeshToTensor& composer) -> Tensor { + return aggregate_tensor(get_cpu_tensor(tensor), composer); + }, +>>>>>>> fix test mappers, convert to cpu_tensor py::arg("tensor"), py::arg("composer")); module.def( "aggregate_tensor", [](const std::vector& tensors, const MeshToTensor& composer) -> Tensor { +<<<<<<< HEAD <<<<<<< HEAD return aggregate_tensor(from_device(aggregate_as_tensor(tensors, AllGatherTensor{})), composer); }, @@ -1032,6 +1066,9 @@ void py_module(py::module& module) { >>>>>>> one type error left ======= return aggregate_tensor(aggregate_as_tensor(tensors, AllGatherTensor{}), composer); +======= + return aggregate_tensor(get_cpu_tensor(aggregate_as_tensor(tensors, AllGatherTensor{})), composer); +>>>>>>> fix test mappers, convert to cpu_tensor }, py::arg("tensor"), py::arg("composer")); From 670de836a20ef2e7d4c927cb181afb01f8e3a4ea Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Tue, 18 Feb 2025 20:11:35 +0000 Subject: [PATCH 31/76] clean up imports, fix test cases and change them to use mapper/composer functions, fix storage error by using from_device instead of cpu --- .../distributed/test_distributed_tensor.py | 59 ++++++++----------- .../ttnn/distributed/distributed_pybind.cpp | 50 +++++++++++++--- ttnn/ttnn/__init__.py | 5 ++ 3 files changed, 70 insertions(+), 44 deletions(-) diff --git a/tests/ttnn/distributed/test_distributed_tensor.py b/tests/ttnn/distributed/test_distributed_tensor.py index 5abb2c0d690..3818a304470 100644 --- a/tests/ttnn/distributed/test_distributed_tensor.py +++ b/tests/ttnn/distributed/test_distributed_tensor.py @@ -3,24 +3,10 @@ # SPDX-License-Identifier: Apache-2.0 import torch -import typing import pytest import ttnn -import tempfile from loguru import logger from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc - -from ttnn import ( - distribute_tensor, - aggregate_tensor, - ShardTensorToMesh, - ShardTensor2dMesh, - ReplicateTensorToMesh, - ConcatMeshToTensor, - ConcatMesh2dToTensor, - MeshToTensor, - TensorToMesh, -) from models.utility_functions import nearest_32 @@ -43,7 +29,7 @@ def test_replicate_to_tensor_mesh(mesh_device, dtype): device=mesh_device, ) - mapper = ttnn.ReplicateTensorToMesh(mesh_device) + mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) replicated_tensors = ttnn.distribute_tensor(to_repl, mapper, mesh_device) out_tensors = ttnn.get_device_tensors(replicated_tensors) @@ -64,11 +50,13 @@ def test_shard_to_tensor_mesh(mesh_device, dtype): 
device=mesh_device, ) - mapper = ttnn.ShardTensorToMesh(mesh_device, dim=3) + mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3) + + shards = ttnn.get_device_tensors(ttnn.distribute_tensor(to_shard, mapper, mesh_device)) - out_tensor = ttnn.distribute_tensor(to_shard, mapper, mesh_device) + out_tensor = ttnn.aggregate_as_tensor(shards) - out_pass, out_pcc = comp_pcc(out_tensor, torch_tensor, pcc=0.99) + out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensor), torch_tensor, pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass @@ -85,13 +73,13 @@ def test_concat_to_tensor(mesh_device, dtype): device=mesh_device, ) - mapper = ttnn.ShardTensorToMesh(mesh_device, dim=3) + mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3) - composer = ttnn.ConcatMeshToTensor(dim=3) + composer = ttnn.concat_mesh_to_tensor_composer(dim=3) out_tensor = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device), composer) - out_pass, out_pcc = comp_pcc(out_tensor, torch_tensor, pcc=0.99) + out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensor), torch_tensor, pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass @@ -108,16 +96,17 @@ def test_concat_slice_to_tensor(mesh_device, dtype): device=mesh_device, ) - mapper = ttnn.ShardTensorToMesh(mesh_device, dim=3) + mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3) + + composer = ttnn.concat_mesh_to_tensor_composer(dim=3) + + sharded_tensor = ttnn.distribute_tensor(to_shard, mapper, mesh_device) - composer = ttnn.ConcatMeshToTensor(dim=3) + shards = ttnn.get_device_tensors(sharded_tensor) - out_tensor = [] - out_tensor[0] = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device)[:-2], composer) - out_tensor[1] = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device)[:-1], composer) - out_tensor[2] = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device)[:0], composer) + out_tensor = ttnn.aggregate_tensor(shards, composer) - out_pass, out_pcc = comp_pcc(out_tensor, torch_tensor, pcc=0.99) + out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensor), torch_tensor, pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass @@ -158,13 +147,13 @@ def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): device=mesh_device, ) - mapper = ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) + mapper = ttnn.shard_tensor_to_2d_mesh_mapper(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) - out_tensors = ttnn.get_device_tensors(ttnn.distribute_tensor(to_shard, mapper, mesh_device)) + shards = ttnn.get_device_tensors(ttnn.distribute_tensor(to_shard, mapper, mesh_device)) - ttnn.aggregate_as_tensor(out_tensors, mesh_device) + ttnn.aggregate_as_tensor(shards) - out_pass, out_pcc = comp_pcc(out_tensors, torch_tensor, pcc=0.99) + out_pass, out_pcc = comp_pcc(ttnn.to_torch(shards), torch_tensor, pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass @@ -206,12 +195,12 @@ def test_concat2d_to_tensor(M, K, N, dtype, mesh_shape, mesh_device): device=mesh_device, ) - mapper = ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) + mapper = ttnn.shard_tensor_to_2d_mesh_mapper(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) - composer = ttnn.ConcatMesh2dToTensor(mesh_device, dims=concat_dim, mesh_shape=mesh_shape) + composer = ttnn.concat_2d_mesh_to_tensor_composer(mesh_device, dims=concat_dim, mesh_shape=mesh_shape) out_tensor = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, 
mesh_device), composer) - out_pass, out_pcc = comp_pcc(out_tensor, torch_tensor, pcc=0.99) + out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensor), torch_tensor, pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 800da6df25e..9a5107b65f6 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -6,6 +6,7 @@ <<<<<<< HEAD <<<<<<< HEAD #include +<<<<<<< HEAD <<<<<<< HEAD #include @@ -17,6 +18,14 @@ #include #include >>>>>>> expose classes to python +======= +<<<<<<< HEAD +#include +#include +#include +======= +>>>>>>> clean up imports, fix test cases and change them to use mapper/composer functions, fix storage error by using from_device instead of cpu +>>>>>>> clean up imports, fix test cases and change them to use mapper/composer functions, fix storage error by using from_device instead of cpu ======= #include @@ -40,6 +49,7 @@ #include "ttnn/distributed/types.hpp" #include "ttnn/operations/core/core.hpp" <<<<<<< HEAD +<<<<<<< HEAD #include "ttnn/tensor/tensor_utils.hpp" >>>>>>> one type error left ======= @@ -57,6 +67,14 @@ >>>>>>> one type error left #include "ttnn/tensor/tensor.hpp" #include +======= +#include "ttnn/tensor/tensor.hpp" +<<<<<<< HEAD +#include "ttnn/types.hpp" +======= +#include +>>>>>>> clean up imports, fix test cases and change them to use mapper/composer functions, fix storage error by using from_device instead of cpu +>>>>>>> clean up imports, fix test cases and change them to use mapper/composer functions, fix storage error by using from_device instead of cpu // This is required for automatic conversions, as in the creation of mesh devices // https://github.com/tenstorrent/tt-metal/issues/18082 @@ -88,14 +106,6 @@ struct ConcreteMeshToTensor : MeshToTensor { } }; -Tensor get_cpu_tensor(const Tensor& tensor) { - if (is_device_tensor(tensor)) { - Tensor cpu_tensor = tensor.cpu(); - TT_ASSERT(is_device_tensor(cpu_tensor)); - } - return tensor; -} - void py_module_types(py::module& module) { <<<<<<< HEAD <<<<<<< HEAD @@ -898,6 +908,11 @@ void py_module(py::module& module) { R"doc( Get the tensor shard corresponding to the device. +<<<<<<< HEAD +======= +<<<<<<< HEAD + +>>>>>>> clean up imports, fix test cases and change them to use mapper/composer functions, fix storage error by using from_device instead of cpu Args: tensor (Tensor): The tensor to get the shard from. device (Device): The device to get the shard for. @@ -906,6 +921,15 @@ void py_module(py::module& module) { Returns: Tensor: The shard of the tensor corresponding to the device. )doc"); +======= + Args: + tensor (Tensor): The tensor to get the shard from. + device (Device): The device to get the shard for. +aggregate_as + Returns: + Tensor: The shard of the tensor corresponding to the device. 
+ )doc"); +>>>>>>> clean up imports, fix test cases and change them to use mapper/composer functions, fix storage error by using from_device instead of cpu module.def("get_device_tensors", &get_device_tensors, py::arg("tensor"), py::kw_only()); <<<<<<< HEAD <<<<<<< HEAD @@ -1014,6 +1038,7 @@ void py_module(py::module& module) { const TensorToMesh& mapper, std::optional> mesh_device) -> Tensor { <<<<<<< HEAD +<<<<<<< HEAD <<<<<<< HEAD return distribute_tensor(from_device(tensor), mapper, mesh_device); }, @@ -1030,6 +1055,9 @@ void py_module(py::module& module) { ======= return distribute_tensor(get_cpu_tensor(tensor), mapper, mesh_device); >>>>>>> fix test mappers, convert to cpu_tensor +======= + return distribute_tensor(from_device(tensor), mapper, mesh_device); +>>>>>>> clean up imports, fix test cases and change them to use mapper/composer functions, fix storage error by using from_device instead of cpu }, py::arg("tensor"), py::arg("mapper"), @@ -1041,7 +1069,7 @@ void py_module(py::module& module) { >>>>>>> fix mesh device conflict, add aggregate/distribute and config pybinds, fix keyword error ======= [](const Tensor& tensor, const MeshToTensor& composer) -> Tensor { - return aggregate_tensor(get_cpu_tensor(tensor), composer); + return aggregate_tensor(from_device(tensor), composer); }, >>>>>>> fix test mappers, convert to cpu_tensor py::arg("tensor"), @@ -1050,6 +1078,7 @@ void py_module(py::module& module) { "aggregate_tensor", [](const std::vector& tensors, const MeshToTensor& composer) -> Tensor { <<<<<<< HEAD +<<<<<<< HEAD <<<<<<< HEAD return aggregate_tensor(from_device(aggregate_as_tensor(tensors, AllGatherTensor{})), composer); }, @@ -1069,6 +1098,9 @@ void py_module(py::module& module) { ======= return aggregate_tensor(get_cpu_tensor(aggregate_as_tensor(tensors, AllGatherTensor{})), composer); >>>>>>> fix test mappers, convert to cpu_tensor +======= + return aggregate_tensor(from_device(aggregate_as_tensor(tensors, AllGatherTensor{})), composer); +>>>>>>> clean up imports, fix test cases and change them to use mapper/composer functions, fix storage error by using from_device instead of cpu }, py::arg("tensor"), py::arg("composer")); diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index cb71132c19d..312eeb61551 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -115,6 +115,11 @@ def manage_config(name, value): get_concat2d_config, get_distributed_tensor_config, aggregate_as_tensor, + replicate_tensor_to_mesh_mapper, + shard_tensor_to_mesh_mapper, + shard_tensor_to_2d_mesh_mapper, + concat_mesh_to_tensor_composer, + concat_2d_mesh_to_tensor_composer, aggregate_tensor, distribute_tensor, get_t3k_physical_device_ids_ring, From 1d1ff5af168fd5101a9580c55228fc8c1f68501f Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Tue, 18 Feb 2025 20:15:29 +0000 Subject: [PATCH 32/76] remove python implementations --- ttnn/ttnn/distributed/distributed.py | 232 --------------------------- 1 file changed, 232 deletions(-) diff --git a/ttnn/ttnn/distributed/distributed.py b/ttnn/ttnn/distributed/distributed.py index d59377181df..b29089e72f6 100644 --- a/ttnn/ttnn/distributed/distributed.py +++ b/ttnn/ttnn/distributed/distributed.py @@ -209,238 +209,6 @@ def synchronize_devices( ttnn._ttnn.device.synchronize_device(devices.get_device(device), queue_id, sub_device_ids) -# class TensorToMesh: -# """ -# Defines the mapping of a torch.Tensor to a device mesh: e.g. Shard/Replicate. -# You can also "Bring your own TensorToMesh" based on your custom mapping. 
-# """ - -# def __init__(self, mesh_device): -# self.mesh_device = mesh_device - -# def map(self, tensor: "torch.Tensor"): -# raise NotImplementedError("Subclasses must implement this method") - -# def config(self): -# raise NotImplementedError("Subclasses must implement this method") - - -# class MeshToTensor: -# """ -# Defines the inverse operation of TensorToMesh. Given a set of per-device -# ttnn.Tensor objects (aggregated into a single ttnn.Tensor), this class defines -# the mapping back to one or many torch.Tensor objects. - -# You can also "Bring your own MeshToTensor" based on your custom mapping. -# """ - -# def compose(self, tensor: ttnn.Tensor): -# raise NotImplementedError("Subclasses must implement this method") - - -# class ShardTensorToMesh(TensorToMesh): -# def __init__(self, mesh_device, dim): -# super().__init__(mesh_device) -# self.shard_dim = dim - -# def map(self, tensor: "torch.Tensor") -> Dict[int, ttnn.Tensor]: -# import torch - -# sliced_tensors = torch.chunk(tensor, self.mesh_device.get_num_devices(), dim=self.shard_dim) -# return list(sliced_tensors) - -# def config(self): -# return { -# "strategy": "shard", -# "shard_dim": f"{self.shard_dim}", -# } - - -# class ShardTensor2dMesh(TensorToMesh): -# """ -# Shard a tensor across a 2D mesh of devices. - -# This class implements a strategy for distributing a tensor across a 2D grid of devices, -# allowing for efficient parallel processing in distributed computing environments. -# """ - -# def __init__(self, mesh_device: MeshDevice, mesh_shape: Tuple[int, int], dims: Tuple[Optional[int], Optional[int]]): -# """ -# Initialize the ShardTensor2dMesh. - -# Args: -# mesh_device: The target device mesh for distributing the tensor. -# mesh_shape: The shape of the 2D mesh as (rows, cols). -# dims: The dimensions to shard along, specified as (row_dim, col_dim). - -# The `dims` tuple determines how the tensor is sharded across the 2D mesh: -# - row_dim: The dimension to shard across mesh rows (or None for replication). -# - col_dim: The dimension to shard across mesh columns (or None for replication). - -# Examples: -# 1. dims=(2, 3) for a tensor of shape (A, B, C, D): -# - Shard along dimension 2 (C) across mesh rows -# - Shard along dimension 3 (D) across mesh columns - -# 2. dims=(None, 3): -# - Replicate across mesh rows -# - Shard along dimension 3 (D) across mesh columns - -# 3. dims=(None, None): -# - Fully replicate the tensor across all devices -# """ -# super().__init__(mesh_device) -# self.mesh_shape: Tuple[int, int] = mesh_shape -# self.dims: Tuple[Optional[int], Optional[int]] = dims - -# mesh_device_rows, mesh_device_cols = self.mesh_device.shape -# if mesh_shape[0] > mesh_device_rows or mesh_shape[1] > mesh_device_cols: -# raise ValueError("ShardTensor2dMesh: Device mesh shape does not match the provided mesh shape.") - -# def map(self, tensor: "torch.Tensor") -> List["torch.Tensor"]: -# """ -# Map the input tensor to a list of sharded tensors. - -# Args: -# tensor: The input tensor to be sharded. - -# Returns: -# A list of sharded tensors, one for each device in the mesh. - -# Raises: -# ValueError: If the number of sharding dimensions is not 2. 
-# """ -# import torch - -# if len(self.dims) != 2: -# raise ValueError("ShardTensor2dMesh only supports 2D shard dimensions") - -# rows, cols = self.mesh_shape -# row_dim, col_dim = self.dims - -# # Shard along rows -# row_tensors = ( -# [tensor.clone() for _ in range(rows)] if row_dim is None else torch.chunk(tensor, rows, dim=row_dim) -# ) - -# # Shard along columns -# if col_dim is None: -# return [t.clone() for t in row_tensors for _ in range(cols)] -# tensor_shards = [tt for t in row_tensors for tt in torch.chunk(t, cols, dim=col_dim)] - -# if len(tensor_shards) != rows * cols: -# raise ValueError( -# f"ShardTensor2dMesh: Sharding failed. Number of shards should match the product of the mesh dimensions. Got {len(tensor_shards)} shards but expected {rows * cols} ({rows} rows * {cols} cols)." -# ) - -# return tensor_shards - -# def config(self) -> Dict[str, str]: -# """ -# Provide the configuration of the sharding strategy. - -# Returns: -# A dictionary containing the sharding strategy and dimensions. -# """ -# return { -# "strategy": "shard_2d", -# "mesh_shape_y": str(self.mesh_shape[0]), -# "mesh_shape_x": str(self.mesh_shape[1]), -# } - - -# class ConcatMesh2dToTensor(MeshToTensor): -# """ -# Concatenate tensors from a 2D mesh back into a single tensor. - -# This class implements the inverse operation of ShardTensor2dMesh, combining -# sharded tensors from a 2D device mesh back into a single tensor. -# """ - -# def __init__(self, mesh_device: MeshDevice, mesh_shape: Tuple[int, int], dims: Tuple[int, int]): -# """ -# Initialize the ConcatMesh2dToTensor. - -# Args: -# mesh_device: The source device mesh containing the sharded tensors. -# mesh_shape: The shape of the 2D mesh as (rows, cols). -# dims: A tuple of two integers specifying the dimensions along which to concatenate the tensors. -# The first element (row_dim) indicates the dimension for concatenating tensors from different rows. -# The second element (col_dim) indicates the dimension for concatenating tensors from different columns. -# Both dimensions must be specified and different from each other. -# These dimensions correspond to the tensor dimensions, not the mesh dimensions. -# For example, if the original tensor was 4D with shape (batch, channel, height, width), -# and it was sharded across height and width, dims might be (-2, -1) or (2, 3). - -# Raises: -# ValueError: If either dimension in 'dims' is None or if both dimensions are the same. -# """ -# self.mesh_device = mesh_device -# self.mesh_shape = mesh_shape -# self.dims = dims -# if self.dims[0] == self.dims[1]: -# raise ValueError("Both dimensions in 'dims' must be different") - -# def compose(self, tensor: ttnn.Tensor) -> "torch.Tensor": -# """ -# Compose the sharded tensors back into a single tensor. - -# Args: -# tensor: A ttnn.Tensor object containing the sharded tensors distributed across multiple devices. - -# Returns: -# A single torch.Tensor that combines all the sharded tensors from all devices. - -# This method first concatenates the shards along the column dimension within each row, -# then concatenates the resulting tensors along the row dimension to form the final tensor. 
-# """ -# import torch - -# device_shards = [ -# ttnn.to_torch(tt_input_tensor, mesh_composer=None) for tt_input_tensor in ttnn.get_device_tensors(tensor) -# ] - -# rows, cols = self.mesh_shape -# row_dim, col_dim = self.dims - -# # Reshape the list of shards into a 2D list representing the device mesh -# mesh_shape = [device_shards[i : i + cols] for i in range(0, len(device_shards), cols)] - -# # Concatenate along columns first (within each row) -# row_concatenated = [torch.cat(row, dim=col_dim) for row in mesh_shape] - -# # Then concatenate the resulting tensors along rows -# return torch.cat(row_concatenated, dim=row_dim) - - -# class ReplicateTensorToMesh(TensorToMesh): -# def __init__(self, mesh_device: MeshDevice): -# super().__init__(mesh_device) - -# def map(self, tensor: "torch.Tensor"): -# return [tensor for i in range(self.mesh_device.get_num_devices())] - -# def config(self): -# return { -# "strategy": "replicate", -# "replication_factor": str(self.mesh_device.get_num_devices()), -# } - - -# class ConcatMeshToTensor(MeshToTensor): -# def __init__(self, mesh_device: MeshDevice, dim: int): -# self.concat_dim = dim -# self.mesh_device = mesh_device - -# def compose(self, tensor: ttnn.Tensor) -> "torch.Tensor": -# import torch - -# device_shards_converted_to_torch = [ -# ttnn.to_torch(tt_input_tensor, mesh_composer=None) for tt_input_tensor in ttnn.get_device_tensors(tensor) -# ] -# return torch.cat(device_shards_converted_to_torch, dim=self.concat_dim) - - @contextlib.contextmanager def distribute(default: Union[ttnn.TensorToMesh, ttnn.MeshToTensor]): """ From 0b679dbef18147f46f2251aeaff97ad7c43bca75 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Tue, 18 Feb 2025 20:24:57 +0000 Subject: [PATCH 33/76] fix rebase --- .../ttnn/distributed/distributed_pybind.cpp | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 9a5107b65f6..9e0e624526a 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -13,6 +13,7 @@ ======= ======= #include +<<<<<<< HEAD >>>>>>> one type error left #include #include @@ -28,6 +29,11 @@ >>>>>>> clean up imports, fix test cases and change them to use mapper/composer functions, fix storage error by using from_device instead of cpu ======= +#include +>>>>>>> fix rebase +======= +#include + #include >>>>>>> fix rebase #include @@ -69,12 +75,12 @@ #include ======= #include "ttnn/tensor/tensor.hpp" -<<<<<<< HEAD -#include "ttnn/types.hpp" -======= #include +<<<<<<< HEAD >>>>>>> clean up imports, fix test cases and change them to use mapper/composer functions, fix storage error by using from_device instead of cpu >>>>>>> clean up imports, fix test cases and change them to use mapper/composer functions, fix storage error by using from_device instead of cpu +======= +>>>>>>> fix rebase // This is required for automatic conversions, as in the creation of mesh devices // https://github.com/tenstorrent/tt-metal/issues/18082 @@ -908,11 +914,14 @@ void py_module(py::module& module) { R"doc( Get the tensor shard corresponding to the device. +<<<<<<< HEAD <<<<<<< HEAD ======= <<<<<<< HEAD >>>>>>> clean up imports, fix test cases and change them to use mapper/composer functions, fix storage error by using from_device instead of cpu +======= +>>>>>>> fix rebase Args: tensor (Tensor): The tensor to get the shard from. device (Device): The device to get the shard for. 
@@ -921,15 +930,6 @@ void py_module(py::module& module) { Returns: Tensor: The shard of the tensor corresponding to the device. )doc"); -======= - Args: - tensor (Tensor): The tensor to get the shard from. - device (Device): The device to get the shard for. -aggregate_as - Returns: - Tensor: The shard of the tensor corresponding to the device. - )doc"); ->>>>>>> clean up imports, fix test cases and change them to use mapper/composer functions, fix storage error by using from_device instead of cpu module.def("get_device_tensors", &get_device_tensors, py::arg("tensor"), py::kw_only()); <<<<<<< HEAD <<<<<<< HEAD From c8feeae4c80800336087aad29b3488f076f35cba Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Wed, 19 Feb 2025 00:35:00 +0000 Subject: [PATCH 34/76] add shard2dconfig, concat2dconfig methods and map/compose constructors --- .../ttnn/distributed/distributed_pybind.cpp | 78 +++++++++++++++++-- .../ttnn/distributed/distributed_tensor.cpp | 8 ++ .../ttnn/distributed/distributed_tensor.hpp | 8 +- 3 files changed, 84 insertions(+), 10 deletions(-) diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 9e0e624526a..912bf75939d 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -169,6 +169,9 @@ void py_module_types(py::module& module) { py::class_(module, "AllGatherTensor"); py::class_(module, "DistributedTensorConfig"); + py::class_(module, "Shard2dConfig"); + py::class_(module, "Concat2dConfig"); + py::class_>(module, "MeshDevice"); py::class_(module, "MeshSubDeviceManagerId"); py::class_(module, "MeshShape", "Shape of a mesh device."); @@ -653,11 +656,9 @@ void py_module(py::module& module) { .def(py::init([]() -> std::unique_ptr { return std::make_unique(); })) .def("map", &TensorToMesh::map) .def("config", &TensorToMesh::config); - auto py_replicate_tensor_to_mesh = static_cast>>( module.attr("ReplicateTensorToMesh")); - py_replicate_tensor_to_mesh .def( py::init([](MeshDevice& mesh_device) -> std::unique_ptr { @@ -674,7 +675,6 @@ void py_module(py::module& module) { [](const ReplicateTensorToMesh& self, const Tensor& tensor) { return self.map(tensor); }, py::arg("tensor")) .def("config", &ReplicateTensorToMesh::config); - auto py_shard_tensor_to_mesh = static_cast>>( module.attr("ShardTensorToMesh")); py_shard_tensor_to_mesh @@ -695,7 +695,6 @@ void py_module(py::module& module) { [](const ShardTensorToMesh& self, const Tensor& tensor) { return self.map(tensor); }, py::arg("tensor")) .def("config", &ShardTensorToMesh::config); - auto py_shard_tensor_to_2d_mesh = static_cast>>( module.attr("ShardTensor2dMesh")); py_shard_tensor_to_2d_mesh @@ -706,6 +705,18 @@ void py_module(py::module& module) { const Shard2dConfig& config) -> std::unique_ptr { return std::make_unique(ShardTensor2dMesh(mesh_device, mesh_shape, config)); }), + py::init( + [](MeshDevice& mesh_device, + const MeshShape& mesh_shape, + const std::tuple& config) -> std::unique_ptr { + return std::make_unique(ShardTensor2dMesh( + mesh_device, + mesh_shape, + Shard2dConfig{ + .row_dim = std::get<0>(config), + .col_dim = std::get<1>(config), + })); + }), py::arg("mesh_device"), py::arg("mesh_shape"), py::arg("config")) @@ -721,13 +732,11 @@ void py_module(py::module& module) { [](const ShardTensor2dMesh& self, const Tensor& tensor) { return self.map(tensor); }, py::arg("tensor")) .def("config", &ShardTensor2dMesh::config); - auto py_mesh_to_tensor = static_cast>>( module.attr("MeshToTensor")); 
py_mesh_to_tensor .def(py::init([]() -> std::unique_ptr { return std::make_unique(); })) .def("compose", &MeshToTensor::compose); - auto py_concat_mesh_to_tensor = static_cast>>( module.attr("ConcatMeshToTensor")); py_concat_mesh_to_tensor @@ -740,7 +749,6 @@ void py_module(py::module& module) { "compose", [](const ConcatMeshToTensor& self, const std::vector& tensors) { return self.compose(tensors); }, py::arg("tensors")); - auto py_concat_2d_mesh_to_tensor = static_cast>>( module.attr("ConcatMesh2dToTensor")); @@ -768,6 +776,24 @@ void py_module(py::module& module) { config.col_dim); return std::make_unique(mesh_device, config); }), + py::init( + [](MeshDevice& mesh_device, + const std::tuple config) -> std::unique_ptr { + int row_dim = std::get<0>(config); + int col_dim = std::get<1>(config); + TT_FATAL( + row_dim != col_dim, + "Dimensions in 'dims' must be different; got row_dim: {}, col_dim: {}", + row_dim, + col_dim); + return std::make_unique( + mesh_device, + Concat2dConfig{ + .row_dim = row_dim, + .col_dim = col_dim, + }); + }), + py::arg("mesh_device"), py::arg("config")) .def( @@ -827,6 +853,7 @@ void py_module(py::module& module) { py_shard_tensor_config.def(py::init(), py::arg("shard_dimension")) .def_readwrite("shard_dimension", &ShardTensor::shard_dimension) .def("__eq__", [](const ShardTensor& a, const ShardTensor& b) { return a == b; }); +<<<<<<< HEAD <<<<<<< HEAD auto py_shard_mesh = static_cast>(module.attr("ShardMesh")); py_shard_mesh.def(py::init<>()).def_readwrite("y", &ShardMesh::y).def_readwrite("x", &ShardMesh::x); @@ -849,25 +876,38 @@ void py_module(py::module& module) { .def_readwrite("col_dim", &Concat2dConfig::col_dim); ======= +======= +>>>>>>> add shard2dconfig, concat2dconfig methods and map/compose constructors auto py_shard_mesh = static_cast>(module.attr("ShardMesh")); py_shard_mesh.def(py::init<>()).def_readwrite("y", &ShardMesh::y).def_readwrite("x", &ShardMesh::x); - auto py_shard_tensor2d = static_cast>(module.attr("ShardTensor2d")); py_shard_tensor2d.def(py::init(), py::arg("mesh")) .def_readonly("shard_mesh", &ShardTensor2D::shard_mesh) .def("__eq__", [](const ShardTensor2D& a, const ShardTensor2D& b) { return a == b; }); +<<<<<<< HEAD <<<<<<< HEAD auto py_allgather_config = static_cast>(module.attr("AllGatherTensor")); .def(py::init<>()).def("__eq__", [](const AllGatherTensor& a, const AllGatherTensor& b) { return a == b; }); >>>>>>> add configs to pybind ======= +======= +>>>>>>> add shard2dconfig, concat2dconfig methods and map/compose constructors auto py_allgather_config = static_cast>(module.attr("AllGatherTensor")) .def(py::init<>()) .def("__eq__", [](const AllGatherTensor& a, const AllGatherTensor& b) { return a == b; }); >>>>>>> fix test mappers, convert to cpu_tensor + auto py_shard2d_config = static_cast>(module.attr("Shard2dConfig")); + py_shard2d_config.def(py::init(), py::arg("row_dim"), py::arg("col_dim")) + .def_readwrite("row_dim", &Shard2dConfig::row_dim) + .def_readwrite("col_dim", &Shard2dConfig::col_dim); + auto py_concat2d_config = static_cast>(module.attr("Concat2dConfig")); + py_concat2d_config.def(py::init(), py::arg("row_dim"), py::arg("col_dim")) + .def_readwrite("row_dim", &Concat2dConfig::row_dim) + .def_readwrite("col_dim", &Concat2dConfig::col_dim); + module.def( "get_distributed_tensor_config", &get_distributed_tensor_config, @@ -905,6 +945,28 @@ void py_module(py::module& module) { )doc"); ======= >>>>>>> add configs to pybind + module.def( + "get_shard2d_config", + &get_shard2d_config, + py::arg("metadata"), + 
R"doc( + Returns a Shard2dConfig object given a valid metadata object of the type + { + "row_dim": "field", + "col_dim": "field", + } + )doc"); + module.def( + "get_concat2d_config", + &get_concat2d_config, + py::arg("metadata"), + R"doc( + Returns a Concat2dConfig object given a valid metadata object of the type + { + "row_dim": "field", + "col_dim": "field", + } + )doc"); module.def( "get_device_tensor", py::overload_cast(&ttnn::distributed::get_device_tensor), diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp index dccae79d68f..7bc781cff30 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp @@ -222,11 +222,19 @@ Tensor aggregate_tensor(const Tensor& tensor, const MeshToTensor& composer) { : composer.compose({tensor}); } +<<<<<<< HEAD Shard2dConfig get_shard2d_config(const std::unordered_map& metadata) { return Shard2dConfig(std::stoi(metadata.at("row_dim")), std::stoi(metadata.at("col_dim"))); } Concat2dConfig get_concat2d_config(const std::unordered_map& metadata) { +======= +static Shard2dConfig get_shard2d_config(const std::unordered_map& metadata) { + return Shard2dConfig(std::stoi(metadata.at("row_dim")), std::stoi(metadata.at("col_dim"))); +} + +static Concat2dConfig get_concat2d_config(const std::unordered_map& metadata) { +>>>>>>> add shard2dconfig, concat2dconfig methods and map/compose constructors return Concat2dConfig(std::stoi(metadata.at("row_dim")), std::stoi(metadata.at("col_dim"))); } diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp index 7ae8d24cb5b..1c8daf0d6a8 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp @@ -101,7 +101,11 @@ class ShardTensor2dMesh : public TensorToMesh { } ShardTensor2dMesh(const MeshShape& mesh_shape, const Shard2dConfig& config) : - mesh_shape_(mesh_shape), config_(config) {} + mesh_shape_(mesh_shape), config_(config) { + TT_FATAL( + config.row_dim.has_value() || config.col_dim.has_value(), + "Sharding a tensor to 2D mesh requires at least one dimension to shard"); + } std::vector map(const Tensor& tensor) const override { const auto [rows, cols] = mesh_shape_; @@ -147,7 +151,7 @@ class ShardTensor2dMesh : public TensorToMesh { } tt::tt_metal::DistributedTensorConfig config() const override { - return DistributedTensorConfig{ShardTensor2D{ShardMesh{mesh_shape_.num_rows, mesh_shape_.num_cols}}}; + return DistributedTensorConfig{ShardTensor2D{ShardMesh{.y = mesh_shape_.num_rows, .x = mesh_shape_.num_cols}}}; } private: From 1d53fb9683fb60981599a5a97f49214060cad228 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Wed, 19 Feb 2025 22:00:52 +0000 Subject: [PATCH 35/76] Replace none types, expose configs, fix tuple errors --- .../distributed/test_distributed_tensor.py | 4 +- .../ttnn/distributed/distributed_pybind.cpp | 136 +++++++++++++++--- .../ttnn/distributed/distributed_tensor.cpp | 8 ++ 3 files changed, 130 insertions(+), 18 deletions(-) diff --git a/tests/ttnn/distributed/test_distributed_tensor.py b/tests/ttnn/distributed/test_distributed_tensor.py index 3818a304470..b5456cfa0d2 100644 --- a/tests/ttnn/distributed/test_distributed_tensor.py +++ b/tests/ttnn/distributed/test_distributed_tensor.py @@ -126,7 +126,7 @@ def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): core_grid = ttnn.CoreGrid(y=1, x=8) # If K < N it's FF1-like test case, else FF2-like test case - 
shard_dim = (None, 3) if K < N else (3, None) # None means to replicate along this dim + shard_dim = (0, 3) if K < N else (3, 0) # None means to replicate along this dim K = K // mesh_shape[1] if K < N else K // mesh_shape[0] N = N // mesh_shape[0] if K < N else N // mesh_shape[1] @@ -173,7 +173,7 @@ def test_concat2d_to_tensor(M, K, N, dtype, mesh_shape, mesh_device): core_grid = ttnn.CoreGrid(y=1, x=8) # If K < N it's FF1-like test case, else FF2-like test case - shard_dim = (None, 3) if K < N else (3, None) # None means to replicate along this dim + shard_dim = (0, 3) if K < N else (3, 0) # None means to replicate along this dim concat_dim = (3, 1) if K < N else (1, 3) # dim 1 for reduce, dim 3 for concatenating fractures K = K // mesh_shape[1] if K < N else K // mesh_shape[0] diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 912bf75939d..7d0647d02e4 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -68,6 +68,7 @@ >>>>>>> fix test mappers, convert to cpu_tensor #include "ttnn/distributed/api.hpp" #include "ttnn/distributed/distributed_tensor_config.hpp" +#include "ttnn/distributed/types.hpp" #include "ttnn/operations/core/core.hpp" #include "ttnn/tensor/tensor_utils.hpp" >>>>>>> one type error left @@ -188,6 +189,7 @@ void py_module(py::module& module) { py::arg("num_rows"), py::arg("num_cols")) .def( +<<<<<<< HEAD <<<<<<< HEAD py::init([](size_t x, size_t y, size_t z) { return MeshShape(x, y, z); }), "Constructor with the specified 3D shape.", @@ -199,6 +201,8 @@ void py_module(py::module& module) { "Constructor with the specified ND shape.", py::arg("shape")) ======= +======= +>>>>>>> Replace none types, expose configs, fix tuple errors py::init([](const std::tuple& dims) { return MeshShape(std::get<0>(dims), std::get<1>(dims)); }), "Constructor with specified number of rows and columns as a tuple (rows, columns).", py::arg("dims")) @@ -703,23 +707,24 @@ void py_module(py::module& module) { [](MeshDevice& mesh_device, const MeshShape& mesh_shape, const Shard2dConfig& config) -> std::unique_ptr { - return std::make_unique(ShardTensor2dMesh(mesh_device, mesh_shape, config)); + return std::make_unique(mesh_device, mesh_shape, config); }), + py::arg("mesh_device"), + py::arg("mesh_shape"), + py::arg("config")) + .def( py::init( [](MeshDevice& mesh_device, - const MeshShape& mesh_shape, - const std::tuple& config) -> std::unique_ptr { - return std::make_unique(ShardTensor2dMesh( + const std::tuple dims, + const MeshShape& mesh_shape) -> std::unique_ptr { + return std::make_unique( mesh_device, mesh_shape, - Shard2dConfig{ - .row_dim = std::get<0>(config), - .col_dim = std::get<1>(config), - })); + Shard2dConfig{.row_dim = std::get<0>(dims), .col_dim = std::get<1>(dims)}); }), py::arg("mesh_device"), - py::arg("mesh_shape"), - py::arg("config")) + py::arg("dims"), + py::arg("mesh_shape")) .def( py::init( [](const MeshShape& mesh_shape, const Shard2dConfig& config) -> std::unique_ptr { @@ -776,11 +781,13 @@ void py_module(py::module& module) { config.col_dim); return std::make_unique(mesh_device, config); }), + py::arg("mesh_device"), + py::arg("config")) + .def( py::init( - [](MeshDevice& mesh_device, - const std::tuple config) -> std::unique_ptr { - int row_dim = std::get<0>(config); - int col_dim = std::get<1>(config); + [](MeshDevice& mesh_device, const std::tuple dims) -> std::unique_ptr { + int row_dim = std::get<0>(dims); + int col_dim = 
std::get<1>(dims); TT_FATAL( row_dim != col_dim, "Dimensions in 'dims' must be different; got row_dim: {}, col_dim: {}", @@ -793,12 +800,59 @@ void py_module(py::module& module) { .col_dim = col_dim, }); }), + py::arg("mesh_device"), + py::arg("dims")) + .def( + py::init( + [](MeshDevice& mesh_device, + const Concat2dConfig& config, + MeshShape& mesh_shape) -> std::unique_ptr { + TT_FATAL( + config.row_dim != config.col_dim, + "Dimensions in 'dims' must be different; got row_dim: {}, col_dim: {}", + config.row_dim, + config.col_dim); + TT_FATAL( + mesh_shape.num_rows <= mesh_device.shape().num_rows && // + mesh_shape.num_cols <= mesh_device.shape().num_cols, + "Device mesh shape does not match the provided mesh shape."); + return std::make_unique(mesh_device, config); + }), py::arg("mesh_device"), - py::arg("config")) + py::arg("config"), + py::arg("mesh_shape")) + .def( + py::init( + [](MeshDevice& mesh_device, + const std::tuple dims, + MeshShape& mesh_shape) -> std::unique_ptr { + int row_dim = std::get<0>(dims); + int col_dim = std::get<1>(dims); + + TT_FATAL( + row_dim != col_dim, + "Dimensions in 'dims' must be different; got row_dim: {}, col_dim: {}", + row_dim, + col_dim); + TT_FATAL( + mesh_shape.num_rows <= mesh_device.shape().num_rows && // + mesh_shape.num_cols <= mesh_device.shape().num_cols, + "Device mesh shape does not match the provided mesh shape."); + + return std::make_unique( + mesh_device, + Concat2dConfig{ + .row_dim = row_dim, + .col_dim = col_dim, + }); + }), + py::arg("mesh_device"), + py::arg("dims"), + py::arg("mesh_shape")) .def( "compose", - [](ConcatMesh2dToTensor self, const std::vector& tensors) -> Tensor { + [](const ConcatMesh2dToTensor& self, const std::vector& tensors) -> Tensor { return self.compose(tensors); }, py::arg("tensors")); @@ -1028,6 +1082,9 @@ void py_module(py::module& module) { py::arg("config")); module.def( <<<<<<< HEAD +<<<<<<< HEAD +======= +>>>>>>> Replace none types, expose configs, fix tuple errors "shard_tensor_to_2d_mesh_mapper", [](MeshDevice& mesh_device, const std::tuple mesh_shape, @@ -1052,8 +1109,11 @@ void py_module(py::module& module) { TensorToMesh: The created ShardTensor2dMesh mapper. )doc"); module.def( +<<<<<<< HEAD ======= >>>>>>> move class definitions from from distributed_tensor.cpp to.hpp so they can be exposed to the pybind.cpp; add dummy void methods in .cpp to satisfy linker; add new constructors and factory methods to fix type errors +======= +>>>>>>> Replace none types, expose configs, fix tuple errors "concat_mesh_to_tensor_composer", [](int dim) -> std::unique_ptr { return concat_mesh_to_tensor_composer(dim); }, py::arg("dim")); @@ -1094,6 +1154,50 @@ void py_module(py::module& module) { )doc"); ======= >>>>>>> fix mesh device conflict, add aggregate/distribute and config pybinds, fix keyword error + module.def( + "concat_2d_mesh_to_tensor_composer", + [](MeshDevice& mesh_device, const std::tuple dims) -> std::unique_ptr { + return concat_2d_mesh_to_tensor_composer( + mesh_device, Concat2dConfig{.row_dim = std::get<0>(dims), .col_dim = std::get<1>(dims)}); + }, + py::arg("mesh_device"), + py::arg("dims"), + R"doc( + Create a ConcatMesh2dToTensor composer with the given mesh device and dimensions. + + Args: + mesh_device (MeshDevice): The mesh device to create the composer for. + dims (Tuple[int, int]): The dimensions to create the composer for in (row, column) format. + + Returns: + TensorToMesh: The created ConcatMesh2dToTensor composer. 
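                Example:
                    A minimal round-trip sketch, assuming mesh_device is an already-opened 8x4 MeshDevice
                    and tensor is a ttnn.Tensor created with ttnn.from_torch; it pairs this composer with
                    the 2D shard mapper added in this patch series. The (8, 4) mesh shape and the dim
                    values are illustrative, and keyword spellings vary slightly between revisions:

                        mapper = ttnn.shard_tensor_to_2d_mesh_mapper(mesh_device, mesh_shape=(8, 4), dims=(2, 3))
                        composer = ttnn.concat_2d_mesh_to_tensor_composer(mesh_device, dims=(2, 3))
                        sharded = ttnn.distribute_tensor(tensor, mapper, mesh_device)
                        combined = ttnn.aggregate_tensor(sharded, composer)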
+ )doc"); + module.def( + "concat_2d_mesh_to_tensor_composer", + [](MeshDevice& mesh_device, + const std::tuple dims, + const std::tuple mesh_shape) -> std::unique_ptr { + TT_FATAL( + std::get<0>(mesh_shape) <= mesh_device.shape().num_rows && // + std::get<1>(mesh_shape) <= mesh_device.shape().num_cols, + "Device mesh shape does not match the provided mesh shape."); + return concat_2d_mesh_to_tensor_composer( + mesh_device, Concat2dConfig{.row_dim = std::get<0>(dims), .col_dim = std::get<1>(dims)}); + }, + py::arg("mesh_device"), + py::arg("dims"), + py::arg("mesh_shape"), + R"doc( + Create a ConcatMesh2dToTensor composer with the given mesh device and dimensions. + + Args: + mesh_device (MeshDevice): The mesh device to create the composer for. + dims (Tuple[int, int]): The dimensions to create the composer for in (row, column) format. + mesh_shape (Tuple[int, int]): The shape of the 2D mesh as (num_rows, num_cols). + + Returns: + TensorToMesh: The created ConcatMesh2dToTensor composer. + )doc"); module.def( "distribute_tensor", [](const Tensor& tensor, diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp index 7bc781cff30..f716efeeab8 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp @@ -222,6 +222,7 @@ Tensor aggregate_tensor(const Tensor& tensor, const MeshToTensor& composer) { : composer.compose({tensor}); } +<<<<<<< HEAD <<<<<<< HEAD Shard2dConfig get_shard2d_config(const std::unordered_map& metadata) { return Shard2dConfig(std::stoi(metadata.at("row_dim")), std::stoi(metadata.at("col_dim"))); @@ -235,6 +236,13 @@ static Shard2dConfig get_shard2d_config(const std::unordered_map& metadata) { >>>>>>> add shard2dconfig, concat2dconfig methods and map/compose constructors +======= +Shard2dConfig get_shard2d_config(const std::unordered_map& metadata) { + return Shard2dConfig(std::stoi(metadata.at("row_dim")), std::stoi(metadata.at("col_dim"))); +} + +Concat2dConfig get_concat2d_config(const std::unordered_map& metadata) { +>>>>>>> Replace none types, expose configs, fix tuple errors return Concat2dConfig(std::stoi(metadata.at("row_dim")), std::stoi(metadata.at("col_dim"))); } From 5a696a396a93197c8874785db52b1124c5ded6e9 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Wed, 19 Feb 2025 22:41:08 +0000 Subject: [PATCH 36/76] overload for concatmeshtotensor with meshdevice --- ttnn/cpp/ttnn/distributed/distributed_pybind.cpp | 11 +++++++++++ ttnn/cpp/ttnn/distributed/distributed_tensor.hpp | 13 +++++++++++++ 2 files changed, 24 insertions(+) diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 7d0647d02e4..b0943aebe04 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -750,10 +750,21 @@ void py_module(py::module& module) { return std::make_unique(dim); }), py::arg("dim")) + .def( + py::init([](MeshDevice mesh_device, int dim) -> std::unique_ptr { + return std::make_unique(mesh_device, dim); + }), + py::arg("mesh_device"), + py::arg("dim")) .def( "compose", [](const ConcatMeshToTensor& self, const std::vector& tensors) { return self.compose(tensors); }, + py::arg("tensors")) + .def( + "compose", + [](const ConcatMeshToTensor& self, const Tensor& tensor) { return self.compose(tensor); }, py::arg("tensors")); + auto py_concat_2d_mesh_to_tensor = static_cast>>( module.attr("ConcatMesh2dToTensor")); diff --git 
a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp index 1c8daf0d6a8..e8ef98b928f 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp @@ -171,6 +171,19 @@ class ConcatMeshToTensor : public MeshToTensor { int concat_dim_ = -1; }; +class DeviceConcatMeshToTensor : public ConcatMeshToTensor { +public: + DeviceConcatMeshToTensor(MeshDevice mesh_device, int dim) : mesh_device_(mesh_device), concat_dim_(dim) {} + + Tensor compose(const Tensor& tensor) { + return experimental::xtensor::concat(get_device_tensors(tensor), concat_dim_); + } + +private: + MeshDevice mesh_device_; + int concat_dim_ = -1; +}; + class ConcatMesh2dToTensor : public MeshToTensor { public: ConcatMesh2dToTensor(MeshDevice& mesh_device, const Concat2dConfig& config) : From bcf45087413eccd8fd25f135dc6cb77c9045b394 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Thu, 20 Feb 2025 05:08:54 +0000 Subject: [PATCH 37/76] remove extraneous comments --- tests/ttnn/distributed/test_distributed_tensor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/ttnn/distributed/test_distributed_tensor.py b/tests/ttnn/distributed/test_distributed_tensor.py index b5456cfa0d2..614ffeffd03 100644 --- a/tests/ttnn/distributed/test_distributed_tensor.py +++ b/tests/ttnn/distributed/test_distributed_tensor.py @@ -126,7 +126,7 @@ def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): core_grid = ttnn.CoreGrid(y=1, x=8) # If K < N it's FF1-like test case, else FF2-like test case - shard_dim = (0, 3) if K < N else (3, 0) # None means to replicate along this dim + shard_dim = (0, 3) if K < N else (3, 0) K = K // mesh_shape[1] if K < N else K // mesh_shape[0] N = N // mesh_shape[0] if K < N else N // mesh_shape[1] @@ -173,8 +173,8 @@ def test_concat2d_to_tensor(M, K, N, dtype, mesh_shape, mesh_device): core_grid = ttnn.CoreGrid(y=1, x=8) # If K < N it's FF1-like test case, else FF2-like test case - shard_dim = (0, 3) if K < N else (3, 0) # None means to replicate along this dim - concat_dim = (3, 1) if K < N else (1, 3) # dim 1 for reduce, dim 3 for concatenating fractures + shard_dim = (0, 3) if K < N else (3, 0) + concat_dim = (3, 1) if K < N else (1, 3) K = K // mesh_shape[1] if K < N else K // mesh_shape[0] N = N // mesh_shape[0] if K < N else N // mesh_shape[1] From 4c89683dd01a78cb3d5af72430074ac7aa69a7f3 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Thu, 20 Feb 2025 05:47:43 +0000 Subject: [PATCH 38/76] fix deviceconcat errors --- ttnn/cpp/ttnn/distributed/distributed_pybind.cpp | 4 ++-- ttnn/cpp/ttnn/distributed/distributed_tensor.hpp | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index b0943aebe04..af7997d5dda 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -751,7 +751,7 @@ void py_module(py::module& module) { }), py::arg("dim")) .def( - py::init([](MeshDevice mesh_device, int dim) -> std::unique_ptr { + py::init([](MeshDevice& mesh_device, int dim) -> std::unique_ptr { return std::make_unique(mesh_device, dim); }), py::arg("mesh_device"), @@ -762,7 +762,7 @@ void py_module(py::module& module) { py::arg("tensors")) .def( "compose", - [](const ConcatMeshToTensor& self, const Tensor& tensor) { return self.compose(tensor); }, + [](const DeviceConcatMeshToTensor& self, const Tensor& 
tensor) { return self.compose(tensor); }, py::arg("tensors")); auto py_concat_2d_mesh_to_tensor = diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp index e8ef98b928f..a4eaf552571 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp @@ -173,14 +173,15 @@ class ConcatMeshToTensor : public MeshToTensor { class DeviceConcatMeshToTensor : public ConcatMeshToTensor { public: - DeviceConcatMeshToTensor(MeshDevice mesh_device, int dim) : mesh_device_(mesh_device), concat_dim_(dim) {} + DeviceConcatMeshToTensor(MeshDevice& mesh_device, int dim) : + ConcatMeshToTensor(dim), mesh_device_(mesh_device), concat_dim_(dim) {} - Tensor compose(const Tensor& tensor) { + Tensor compose(const Tensor& tensor) const { return experimental::xtensor::concat(get_device_tensors(tensor), concat_dim_); } private: - MeshDevice mesh_device_; + MeshDevice& mesh_device_; int concat_dim_ = -1; }; From a6d2016f57cbcebf55d8eb241e1b10beb4968f9f Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Thu, 20 Feb 2025 17:16:07 +0000 Subject: [PATCH 39/76] add back distributed.py for now, clean up class overloads --- .../distributed/test_distributed_tensor.py | 12 +- .../ttnn/distributed/distributed_pybind.cpp | 196 +++++----------- .../ttnn/distributed/distributed_tensor.cpp | 8 + .../ttnn/distributed/distributed_tensor.hpp | 47 +--- ttnn/ttnn/__init__.py | 14 +- ttnn/ttnn/distributed/__init__.py | 1 + ttnn/ttnn/distributed/distributed.py | 222 +++++++++++++++++- 7 files changed, 308 insertions(+), 192 deletions(-) diff --git a/tests/ttnn/distributed/test_distributed_tensor.py b/tests/ttnn/distributed/test_distributed_tensor.py index 614ffeffd03..4d3b593287e 100644 --- a/tests/ttnn/distributed/test_distributed_tensor.py +++ b/tests/ttnn/distributed/test_distributed_tensor.py @@ -21,7 +21,7 @@ def test_replicate_to_tensor_mesh(mesh_device, dtype): torch.manual_seed(1234) - torch_tensor = torch.randn(1, 1, 32, 8192) + torch_tensor = torch.randn(1, 1, 32, 256) to_repl = ttnn.from_torch( torch_tensor, dtype=dtype, @@ -42,7 +42,7 @@ def test_replicate_to_tensor_mesh(mesh_device, dtype): def test_shard_to_tensor_mesh(mesh_device, dtype): torch.manual_seed(1234) - torch_tensor = torch.randn(1, 1, 8192, 32768) + torch_tensor = torch.randn(1, 1, 32, 256) to_shard = ttnn.from_torch( torch_tensor, dtype=dtype, @@ -65,7 +65,7 @@ def test_shard_to_tensor_mesh(mesh_device, dtype): def test_concat_to_tensor(mesh_device, dtype): torch.manual_seed(1234) - torch_tensor = torch.randn(1, 1, 8192, 32768) + torch_tensor = torch.randn(1, 1, 32, 256) to_shard = ttnn.from_torch( torch_tensor, dtype=dtype, @@ -88,7 +88,7 @@ def test_concat_to_tensor(mesh_device, dtype): def test_concat_slice_to_tensor(mesh_device, dtype): torch.manual_seed(1234) - torch_tensor = torch.randn(1, 1, 8192, 32768) + torch_tensor = torch.randn(1, 1, 32, 256) to_shard = ttnn.from_torch( torch_tensor, dtype=dtype, @@ -116,7 +116,7 @@ def test_concat_slice_to_tensor(mesh_device, dtype): ) @pytest.mark.parametrize( "M, K, N", - [pytest.param(32, 8192, 28 * 1024), pytest.param(32, 28 * 1024, 8192)], + [pytest.param(32, 64, 128), pytest.param(32, 128, 64)], ) @pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): @@ -163,7 +163,7 @@ def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): ) @pytest.mark.parametrize( 
"M, K, N", - [pytest.param(32, 8192, 28 * 1024), pytest.param(32, 28 * 1024, 8192)], + [pytest.param(32, 128, 64), pytest.param(32, 128, 64)], ) @pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) def test_concat2d_to_tensor(M, K, N, dtype, mesh_shape, mesh_device): diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index af7997d5dda..101ed4e1739 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -115,6 +115,7 @@ struct ConcreteMeshToTensor : MeshToTensor { void py_module_types(py::module& module) { <<<<<<< HEAD +<<<<<<< HEAD <<<<<<< HEAD py::class_>(module, "CppMeshToTensor"); py::class_>(module, "CppTensorToMesh"); @@ -162,6 +163,19 @@ void py_module_types(py::module& module) { py::class_>( module, "ConcatMesh2dToTensor"); >>>>>>> fix naming errors, add tests, add imports - TODO, fix weird aliasing error with meshdevice vs ttnn.multidevice.meshdevice +======= + py::class_>(module, "CppMeshToTensor"); + py::class_>(module, "CppTensorToMesh"); + + py::class_>( + module, "CppReplicateTensorToMesh"); + py::class_>(module, "CppShardTensorToMesh"); + py::class_>( + module, "CppShardTensorTo2dMesh"); + py::class_>(module, "CppConcatMeshToTensor"); + py::class_>( + module, "CppConcat2dMeshToTensor"); +>>>>>>> add back distributed.py for now, clean up class overloads py::class_(module, "ReplicateTensor"); py::class_(module, "ShardTensor"); @@ -654,120 +668,100 @@ void py_module(py::module& module) { >>>>>>> one type error left ======= auto py_tensor_to_mesh = static_cast>>( +<<<<<<< HEAD module.attr("TensorToMesh")); >>>>>>> move class definitions from from distributed_tensor.cpp to.hpp so they can be exposed to the pybind.cpp; add dummy void methods in .cpp to satisfy linker; add new constructors and factory methods to fix type errors +======= + module.attr("CppTensorToMesh")); +>>>>>>> add back distributed.py for now, clean up class overloads py_tensor_to_mesh .def(py::init([]() -> std::unique_ptr { return std::make_unique(); })) .def("map", &TensorToMesh::map) .def("config", &TensorToMesh::config); auto py_replicate_tensor_to_mesh = static_cast>>( - module.attr("ReplicateTensorToMesh")); + module.attr("CppReplicateTensorToMesh")); py_replicate_tensor_to_mesh .def( py::init([](MeshDevice& mesh_device) -> std::unique_ptr { return std::make_unique(ReplicateTensorToMesh(mesh_device.num_devices())); }), py::arg("mesh_device")) - .def( - py::init([](size_t num_devices) -> std::unique_ptr { - return std::make_unique(ReplicateTensorToMesh(num_devices)); - }), - py::arg("num_devices")) .def( "map", [](const ReplicateTensorToMesh& self, const Tensor& tensor) { return self.map(tensor); }, py::arg("tensor")) .def("config", &ReplicateTensorToMesh::config); auto py_shard_tensor_to_mesh = static_cast>>( - module.attr("ShardTensorToMesh")); + module.attr("CppShardTensorToMesh")); py_shard_tensor_to_mesh .def( py::init([](MeshDevice& mesh_device, int dim) -> std::unique_ptr { - return std::make_unique(ShardTensorToMesh(mesh_device, dim)); + return std::make_unique(ShardTensorToMesh(mesh_device.num_devices(), dim)); }), py::arg("mesh_device"), py::arg("dim")) - .def( - py::init([](size_t num_devices, int dim) -> std::unique_ptr { - return std::make_unique(ShardTensorToMesh(num_devices, dim)); - }), - py::arg("num_devices"), - py::arg("dim")) .def( "map", [](const ShardTensorToMesh& self, const Tensor& tensor) { return 
self.map(tensor); }, py::arg("tensor")) .def("config", &ShardTensorToMesh::config); - auto py_shard_tensor_to_2d_mesh = static_cast>>( - module.attr("ShardTensor2dMesh")); + auto py_shard_tensor_to_2d_mesh = + static_cast>>( + module.attr("CppShardTensorTo2dMesh")); py_shard_tensor_to_2d_mesh .def( py::init( [](MeshDevice& mesh_device, - const MeshShape& mesh_shape, - const Shard2dConfig& config) -> std::unique_ptr { - return std::make_unique(mesh_device, mesh_shape, config); - }), - py::arg("mesh_device"), - py::arg("mesh_shape"), - py::arg("config")) - .def( - py::init( - [](MeshDevice& mesh_device, - const std::tuple dims, - const MeshShape& mesh_shape) -> std::unique_ptr { - return std::make_unique( - mesh_device, - mesh_shape, + const std::tuple mesh_shape, + const std::tuple dims) -> std::unique_ptr { + int shape_rows = std::get<0>(mesh_shape); + int shape_cols = std::get<1>(mesh_shape); + + int config_rows = std::get<0>(dims); + int config_cols = std::get<1>(dims); + TT_FATAL( + config_rows || config_cols, + "Sharding a tensor to 2D mesh requires at least one dimension to shard"); + TT_FATAL( + shape_rows <= mesh_device.shape().num_rows && // + shape_cols <= mesh_device.shape().num_cols, + "Device mesh shape does not match the provided mesh shape."); + + return std::make_unique( + MeshShape{.num_rows = shape_rows, .num_cols = shape_cols}, Shard2dConfig{.row_dim = std::get<0>(dims), .col_dim = std::get<1>(dims)}); }), py::arg("mesh_device"), - py::arg("dims"), - py::arg("mesh_shape")) - .def( - py::init( - [](const MeshShape& mesh_shape, const Shard2dConfig& config) -> std::unique_ptr { - return std::make_unique(ShardTensor2dMesh(mesh_shape, config)); - }), py::arg("mesh_shape"), - py::arg("config")) + py::arg("dims")) .def( "map", - [](const ShardTensor2dMesh& self, const Tensor& tensor) { return self.map(tensor); }, + [](const ShardTensorTo2dMesh& self, const Tensor& tensor) { return self.map(tensor); }, py::arg("tensor")) - .def("config", &ShardTensor2dMesh::config); + .def("config", &ShardTensorTo2dMesh::config); auto py_mesh_to_tensor = static_cast>>( - module.attr("MeshToTensor")); + module.attr("CppMeshToTensor")); py_mesh_to_tensor .def(py::init([]() -> std::unique_ptr { return std::make_unique(); })) .def("compose", &MeshToTensor::compose); auto py_concat_mesh_to_tensor = static_cast>>( - module.attr("ConcatMeshToTensor")); + module.attr("CppConcatMeshToTensor")); py_concat_mesh_to_tensor - .def( - py::init([](int dim) -> std::unique_ptr { - return std::make_unique(dim); - }), - py::arg("dim")) .def( py::init([](MeshDevice& mesh_device, int dim) -> std::unique_ptr { - return std::make_unique(mesh_device, dim); + return std::make_unique(dim); }), py::arg("mesh_device"), py::arg("dim")) .def( "compose", [](const ConcatMeshToTensor& self, const std::vector& tensors) { return self.compose(tensors); }, - py::arg("tensors")) - .def( - "compose", - [](const DeviceConcatMeshToTensor& self, const Tensor& tensor) { return self.compose(tensor); }, py::arg("tensors")); auto py_concat_2d_mesh_to_tensor = - static_cast>>( - module.attr("ConcatMesh2dToTensor")); + static_cast>>( + module.attr("CppConcat2dMeshToTensor")); py_concat_2d_mesh_to_tensor <<<<<<< HEAD .def(py::init<>(MeshDevice & mesh_device, const Concat2dConfig& config) { @@ -784,74 +778,22 @@ void py_module(py::module& module) { ======= .def( py::init( - [](MeshDevice& mesh_device, const Concat2dConfig& config) -> std::unique_ptr { - TT_FATAL( - config.row_dim != config.col_dim, - "Dimensions in 'dims' must be different; got 
row_dim: {}, col_dim: {}", - config.row_dim, - config.col_dim); - return std::make_unique(mesh_device, config); - }), - py::arg("mesh_device"), - py::arg("config")) - .def( - py::init( - [](MeshDevice& mesh_device, const std::tuple dims) -> std::unique_ptr { + [](MeshDevice& mesh_device, + const std::tuple mesh_shape, + const std::tuple dims) -> std::unique_ptr { int row_dim = std::get<0>(dims); int col_dim = std::get<1>(dims); TT_FATAL( - row_dim != col_dim, - "Dimensions in 'dims' must be different; got row_dim: {}, col_dim: {}", - row_dim, - col_dim); - return std::make_unique( - mesh_device, - Concat2dConfig{ - .row_dim = row_dim, - .col_dim = col_dim, - }); - }), - py::arg("mesh_device"), - py::arg("dims")) - .def( - py::init( - [](MeshDevice& mesh_device, - const Concat2dConfig& config, - MeshShape& mesh_shape) -> std::unique_ptr { - TT_FATAL( - config.row_dim != config.col_dim, - "Dimensions in 'dims' must be different; got row_dim: {}, col_dim: {}", - config.row_dim, - config.col_dim); - TT_FATAL( - mesh_shape.num_rows <= mesh_device.shape().num_rows && // - mesh_shape.num_cols <= mesh_device.shape().num_cols, + std::get<0>(mesh_shape) <= mesh_device.shape().num_rows && // + std::get<1>(mesh_shape) <= mesh_device.shape().num_cols, "Device mesh shape does not match the provided mesh shape."); - return std::make_unique(mesh_device, config); - }), - py::arg("mesh_device"), - py::arg("config"), - py::arg("mesh_shape")) - .def( - py::init( - [](MeshDevice& mesh_device, - const std::tuple dims, - MeshShape& mesh_shape) -> std::unique_ptr { - int row_dim = std::get<0>(dims); - int col_dim = std::get<1>(dims); - TT_FATAL( row_dim != col_dim, "Dimensions in 'dims' must be different; got row_dim: {}, col_dim: {}", row_dim, col_dim); - TT_FATAL( - mesh_shape.num_rows <= mesh_device.shape().num_rows && // - mesh_shape.num_cols <= mesh_device.shape().num_cols, - "Device mesh shape does not match the provided mesh shape."); - - return std::make_unique( + return std::make_unique( mesh_device, Concat2dConfig{ .row_dim = row_dim, @@ -859,11 +801,11 @@ void py_module(py::module& module) { }); }), py::arg("mesh_device"), - py::arg("dims"), - py::arg("mesh_shape")) + py::arg("Mesh_shape"), + py::arg("dims")) .def( "compose", - [](const ConcatMesh2dToTensor& self, const std::vector& tensors) -> Tensor { + [](const Concat2dMeshToTensor& self, const std::vector& tensors) -> Tensor { return self.compose(tensors); }, py::arg("tensors")); @@ -1165,29 +1107,11 @@ void py_module(py::module& module) { )doc"); ======= >>>>>>> fix mesh device conflict, add aggregate/distribute and config pybinds, fix keyword error - module.def( - "concat_2d_mesh_to_tensor_composer", - [](MeshDevice& mesh_device, const std::tuple dims) -> std::unique_ptr { - return concat_2d_mesh_to_tensor_composer( - mesh_device, Concat2dConfig{.row_dim = std::get<0>(dims), .col_dim = std::get<1>(dims)}); - }, - py::arg("mesh_device"), - py::arg("dims"), - R"doc( - Create a ConcatMesh2dToTensor composer with the given mesh device and dimensions. - - Args: - mesh_device (MeshDevice): The mesh device to create the composer for. - dims (Tuple[int, int]): The dimensions to create the composer for in (row, column) format. - - Returns: - TensorToMesh: The created ConcatMesh2dToTensor composer. 
- )doc"); module.def( "concat_2d_mesh_to_tensor_composer", [](MeshDevice& mesh_device, - const std::tuple dims, - const std::tuple mesh_shape) -> std::unique_ptr { + const std::tuple mesh_shape, + const std::tuple dims) -> std::unique_ptr { TT_FATAL( std::get<0>(mesh_shape) <= mesh_device.shape().num_rows && // std::get<1>(mesh_shape) <= mesh_device.shape().num_cols, diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp index f716efeeab8..50449ebf7fe 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp @@ -163,6 +163,7 @@ std::unique_ptr shard_tensor_to_2d_mesh_mapper( "Device mesh shape does not match the provided mesh shape."); <<<<<<< HEAD <<<<<<< HEAD +<<<<<<< HEAD <<<<<<< HEAD return std::make_unique(mesh_shape[0], mesh_shape[1], config); ======= @@ -174,6 +175,9 @@ std::unique_ptr shard_tensor_to_2d_mesh_mapper( ======= return std::make_unique(mesh_shape, config); >>>>>>> fix naming errors, add tests, add imports - TODO, fix weird aliasing error with meshdevice vs ttnn.multidevice.meshdevice +======= + return std::make_unique(mesh_shape, config); +>>>>>>> add back distributed.py for now, clean up class overloads } std::unique_ptr concat_mesh_to_tensor_composer(int dim) { @@ -188,6 +192,7 @@ std::unique_ptr concat_2d_mesh_to_tensor_composer(MeshDevice& mesh config.col_dim); <<<<<<< HEAD <<<<<<< HEAD +<<<<<<< HEAD <<<<<<< HEAD TT_FATAL(mesh_device.shape().dims() == 2, "Mesh device is not configured as a 2D mesh: {}", mesh_device.shape()); return std::make_unique(mesh_device.shape()[0], mesh_device.shape()[1], config); @@ -200,6 +205,9 @@ std::unique_ptr concat_2d_mesh_to_tensor_composer(MeshDevice& mesh ======= return std::make_unique(mesh_device, config); >>>>>>> fix naming errors, add tests, add imports - TODO, fix weird aliasing error with meshdevice vs ttnn.multidevice.meshdevice +======= + return std::make_unique(mesh_device, config); +>>>>>>> add back distributed.py for now, clean up class overloads } Tensor distribute_tensor( diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp index a4eaf552571..806168721e9 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp @@ -51,8 +51,6 @@ class ReplicateTensorToMesh : public TensorToMesh { public: ReplicateTensorToMesh(size_t num_devices) : num_devices_(num_devices) {} - ReplicateTensorToMesh(MeshDevice& mesh_device) : num_devices_(mesh_device.num_devices()) {} - std::vector map(const Tensor& tensor) const override { std::vector tensors; tensors.reserve(num_devices_); @@ -72,8 +70,6 @@ class ShardTensorToMesh : public TensorToMesh { public: ShardTensorToMesh(size_t num_devices, int dim) : num_devices_(num_devices), shard_dim_(dim) {} - ShardTensorToMesh(MeshDevice& mesh_device, int dim) : num_devices_(mesh_device.num_devices()), shard_dim_(dim) {} - std::vector map(const Tensor& tensor) const override { return experimental::xtensor::chunk(tensor, num_devices_, shard_dim_); } @@ -87,25 +83,10 @@ class ShardTensorToMesh : public TensorToMesh { int shard_dim_ = -1; }; -class ShardTensor2dMesh : public TensorToMesh { +class ShardTensorTo2dMesh : public TensorToMesh { public: - ShardTensor2dMesh(MeshDevice& mesh_device, const MeshShape& mesh_shape, const Shard2dConfig& config) : - mesh_shape_(mesh_shape), config_(config) { - TT_FATAL( - config.row_dim.has_value() || config.col_dim.has_value(), - 
"Sharding a tensor to 2D mesh requires at least one dimension to shard"); - TT_FATAL( - mesh_shape.num_rows <= mesh_device.shape().num_rows && // - mesh_shape.num_cols <= mesh_device.shape().num_cols, - "Device mesh shape does not match the provided mesh shape."); - } - - ShardTensor2dMesh(const MeshShape& mesh_shape, const Shard2dConfig& config) : - mesh_shape_(mesh_shape), config_(config) { - TT_FATAL( - config.row_dim.has_value() || config.col_dim.has_value(), - "Sharding a tensor to 2D mesh requires at least one dimension to shard"); - } + ShardTensorTo2dMesh(const MeshShape& mesh_shape, const Shard2dConfig& config) : + mesh_shape_(mesh_shape), config_(config) {} std::vector map(const Tensor& tensor) const override { const auto [rows, cols] = mesh_shape_; @@ -141,7 +122,7 @@ class ShardTensor2dMesh : public TensorToMesh { TT_FATAL( static_cast(tensor_shards.size()) == rows * cols, - "ShardTensor2dMesh: Sharding failed. Number of shards should match the product of the mesh " + "ShardTensorTo2dMesh: Sharding failed. Number of shards should match the product of the mesh " "dimensions. Size: {}, rows: {}, cols: {}", tensor_shards.size(), rows, @@ -151,7 +132,7 @@ class ShardTensor2dMesh : public TensorToMesh { } tt::tt_metal::DistributedTensorConfig config() const override { - return DistributedTensorConfig{ShardTensor2D{ShardMesh{.y = mesh_shape_.num_rows, .x = mesh_shape_.num_cols}}}; + return DistributedTensorConfig{ShardTensor2D{ShardMesh{mesh_shape_.num_rows, mesh_shape_.num_cols}}}; } private: @@ -171,23 +152,9 @@ class ConcatMeshToTensor : public MeshToTensor { int concat_dim_ = -1; }; -class DeviceConcatMeshToTensor : public ConcatMeshToTensor { -public: - DeviceConcatMeshToTensor(MeshDevice& mesh_device, int dim) : - ConcatMeshToTensor(dim), mesh_device_(mesh_device), concat_dim_(dim) {} - - Tensor compose(const Tensor& tensor) const { - return experimental::xtensor::concat(get_device_tensors(tensor), concat_dim_); - } - -private: - MeshDevice& mesh_device_; - int concat_dim_ = -1; -}; - -class ConcatMesh2dToTensor : public MeshToTensor { +class Concat2dMeshToTensor : public MeshToTensor { public: - ConcatMesh2dToTensor(MeshDevice& mesh_device, const Concat2dConfig& config) : + Concat2dMeshToTensor(MeshDevice& mesh_device, const Concat2dConfig& config) : mesh_shape_(mesh_device.shape()), config_(config) {} Tensor compose(const std::vector& tensors) const override { diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index 312eeb61551..18dbbc78cce 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -96,13 +96,13 @@ def manage_config(name, value): from ttnn._ttnn.multi_device import ( MeshDevice, - MeshToTensor, - TensorToMesh, - ReplicateTensorToMesh, - ShardTensorToMesh, - ShardTensor2dMesh, - ConcatMeshToTensor, - ConcatMesh2dToTensor, + # CppMeshToTensor, + # CppTensorToMesh, + # CppReplicateTensorToMesh, + # CppShardTensorToMesh, + # CppShardTensorTo2dMesh, + # CppConcatMeshToTensor, + # CppConcat2dMeshToTensor, ReplicateTensor, ShardTensor, ShardTensor2d, diff --git a/ttnn/ttnn/distributed/__init__.py b/ttnn/ttnn/distributed/__init__.py index c1fa3c25670..4901c6ae8cb 100644 --- a/ttnn/ttnn/distributed/__init__.py +++ b/ttnn/ttnn/distributed/__init__.py @@ -4,6 +4,7 @@ # TODO: All of the TensorTo and MeshTo classes will be slowly cut out over the next few days from .distributed import ( + MeshDevice, DispatchCoreType, TensorToMesh, ShardTensorToMesh, diff --git a/ttnn/ttnn/distributed/distributed.py b/ttnn/ttnn/distributed/distributed.py index 
b29089e72f6..fa057bd0051 100644 --- a/ttnn/ttnn/distributed/distributed.py +++ b/ttnn/ttnn/distributed/distributed.py @@ -209,8 +209,224 @@ def synchronize_devices( ttnn._ttnn.device.synchronize_device(devices.get_device(device), queue_id, sub_device_ids) +# TODO: All of the TensorTo and MeshTo classes will be slowly cut out over the next few days +class TensorToMesh: + """ + Defines the mapping of a torch.Tensor to a device mesh: e.g. Shard/Replicate. + You can also "Bring your own TensorToMesh" based on your custom mapping. + """ + + def __init__(self, mesh_device): + self.mesh_device = mesh_device + + def map(self, tensor: "torch.Tensor"): + raise NotImplementedError("Subclasses must implement this method") + + def config(self): + raise NotImplementedError("Subclasses must implement this method") + + +class MeshToTensor: + """ + Defines the inverse operation of TensorToMesh. Given a set of per-device + ttnn.Tensor objects (aggregated into a single ttnn.Tensor), this class defines + the mapping back to one or many torch.Tensor objects. + You can also "Bring your own MeshToTensor" based on your custom mapping. + """ + + def compose(self, tensor: ttnn.Tensor): + raise NotImplementedError("Subclasses must implement this method") + + +class ShardTensorToMesh(TensorToMesh): + def __init__(self, mesh_device, dim): + super().__init__(mesh_device) + self.shard_dim = dim + + def map(self, tensor: "torch.Tensor") -> Dict[int, ttnn.Tensor]: + import torch + + sliced_tensors = torch.chunk(tensor, self.mesh_device.get_num_devices(), dim=self.shard_dim) + return list(sliced_tensors) + + def config(self): + return { + "strategy": "shard", + "shard_dim": f"{self.shard_dim}", + } + + +class ShardTensor2dMesh(TensorToMesh): + """ + Shard a tensor across a 2D mesh of devices. + This class implements a strategy for distributing a tensor across a 2D grid of devices, + allowing for efficient parallel processing in distributed computing environments. + """ + + def __init__(self, mesh_device: MeshDevice, mesh_shape: Tuple[int, int], dims: Tuple[Optional[int], Optional[int]]): + """ + Initialize the ShardTensor2dMesh. + Args: + mesh_device: The target device mesh for distributing the tensor. + mesh_shape: The shape of the 2D mesh as (rows, cols). + dims: The dimensions to shard along, specified as (row_dim, col_dim). + The `dims` tuple determines how the tensor is sharded across the 2D mesh: + - row_dim: The dimension to shard across mesh rows (or None for replication). + - col_dim: The dimension to shard across mesh columns (or None for replication). + Examples: + 1. dims=(2, 3) for a tensor of shape (A, B, C, D): + - Shard along dimension 2 (C) across mesh rows + - Shard along dimension 3 (D) across mesh columns + 2. dims=(None, 3): + - Replicate across mesh rows + - Shard along dimension 3 (D) across mesh columns + 3. dims=(None, None): + - Fully replicate the tensor across all devices + """ + super().__init__(mesh_device) + self.mesh_shape: Tuple[int, int] = mesh_shape + self.dims: Tuple[Optional[int], Optional[int]] = dims + + mesh_device_rows, mesh_device_cols = self.mesh_device.shape + if mesh_shape[0] > mesh_device_rows or mesh_shape[1] > mesh_device_cols: + raise ValueError("ShardTensor2dMesh: Device mesh shape does not match the provided mesh shape.") + + def map(self, tensor: "torch.Tensor") -> List["torch.Tensor"]: + """ + Map the input tensor to a list of sharded tensors. + Args: + tensor: The input tensor to be sharded. + Returns: + A list of sharded tensors, one for each device in the mesh. 
+ Raises: + ValueError: If the number of sharding dimensions is not 2. + """ + import torch + + if len(self.dims) != 2: + raise ValueError("ShardTensor2dMesh only supports 2D shard dimensions") + + rows, cols = self.mesh_shape + row_dim, col_dim = self.dims + + # Shard along rows + row_tensors = ( + [tensor.clone() for _ in range(rows)] if row_dim is None else torch.chunk(tensor, rows, dim=row_dim) + ) + + # Shard along columns + if col_dim is None: + return [t.clone() for t in row_tensors for _ in range(cols)] + tensor_shards = [tt for t in row_tensors for tt in torch.chunk(t, cols, dim=col_dim)] + + if len(tensor_shards) != rows * cols: + raise ValueError( + f"ShardTensor2dMesh: Sharding failed. Number of shards should match the product of the mesh dimensions. Got {len(tensor_shards)} shards but expected {rows * cols} ({rows} rows * {cols} cols)." + ) + + return tensor_shards + + def config(self) -> Dict[str, str]: + """ + Provide the configuration of the sharding strategy. + Returns: + A dictionary containing the sharding strategy and dimensions. + """ + return { + "strategy": "shard_2d", + "mesh_shape_y": str(self.mesh_shape[0]), + "mesh_shape_x": str(self.mesh_shape[1]), + } + + +class ConcatMesh2dToTensor(MeshToTensor): + """ + Concatenate tensors from a 2D mesh back into a single tensor. + This class implements the inverse operation of ShardTensor2dMesh, combining + sharded tensors from a 2D device mesh back into a single tensor. + """ + + def __init__(self, mesh_device: MeshDevice, mesh_shape: Tuple[int, int], dims: Tuple[int, int]): + """ + Initialize the ConcatMesh2dToTensor. + Args: + mesh_device: The source device mesh containing the sharded tensors. + mesh_shape: The shape of the 2D mesh as (rows, cols). + dims: A tuple of two integers specifying the dimensions along which to concatenate the tensors. + The first element (row_dim) indicates the dimension for concatenating tensors from different rows. + The second element (col_dim) indicates the dimension for concatenating tensors from different columns. + Both dimensions must be specified and different from each other. + These dimensions correspond to the tensor dimensions, not the mesh dimensions. + For example, if the original tensor was 4D with shape (batch, channel, height, width), + and it was sharded across height and width, dims might be (-2, -1) or (2, 3). + Raises: + ValueError: If either dimension in 'dims' is None or if both dimensions are the same. + """ + self.mesh_device = mesh_device + self.mesh_shape = mesh_shape + self.dims = dims + if self.dims[0] == self.dims[1]: + raise ValueError("Both dimensions in 'dims' must be different") + + def compose(self, tensor: ttnn.Tensor) -> "torch.Tensor": + """ + Compose the sharded tensors back into a single tensor. + Args: + tensor: A ttnn.Tensor object containing the sharded tensors distributed across multiple devices. + Returns: + A single torch.Tensor that combines all the sharded tensors from all devices. + This method first concatenates the shards along the column dimension within each row, + then concatenates the resulting tensors along the row dimension to form the final tensor. 
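        Example:
            A minimal sketch, assuming mesh_device is an already-opened 8x4 MeshDevice and torch_tensor
            is a host tensor whose last two dimensions are divisible by the mesh shape; the mapper and
            composer classes referenced are the Python ones defined in this module:

                mapper = ShardTensor2dMesh(mesh_device, mesh_shape=(8, 4), dims=(2, 3))
                composer = ConcatMesh2dToTensor(mesh_device, mesh_shape=(8, 4), dims=(2, 3))
                tt_tensor = ttnn.from_torch(torch_tensor, layout=ttnn.TILE_LAYOUT, device=mesh_device, mesh_mapper=mapper)
                torch_result = ttnn.to_torch(tt_tensor, mesh_composer=composer)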
+ """ + import torch + + device_shards = [ + ttnn.to_torch(tt_input_tensor, mesh_composer=None) for tt_input_tensor in ttnn.get_device_tensors(tensor) + ] + + rows, cols = self.mesh_shape + row_dim, col_dim = self.dims + + # Reshape the list of shards into a 2D list representing the device mesh + mesh_shape = [device_shards[i : i + cols] for i in range(0, len(device_shards), cols)] + + # Concatenate along columns first (within each row) + row_concatenated = [torch.cat(row, dim=col_dim) for row in mesh_shape] + + # Then concatenate the resulting tensors along rows + return torch.cat(row_concatenated, dim=row_dim) + + +class ReplicateTensorToMesh(TensorToMesh): + def __init__(self, mesh_device: MeshDevice): + super().__init__(mesh_device) + + def map(self, tensor: "torch.Tensor"): + return [tensor for i in range(self.mesh_device.get_num_devices())] + + def config(self): + return { + "strategy": "replicate", + "replication_factor": str(self.mesh_device.get_num_devices()), + } + + +class ConcatMeshToTensor(MeshToTensor): + def __init__(self, mesh_device: MeshDevice, dim: int): + self.concat_dim = dim + self.mesh_device = mesh_device + + def compose(self, tensor: ttnn.Tensor) -> "torch.Tensor": + import torch + + device_shards_converted_to_torch = [ + ttnn.to_torch(tt_input_tensor, mesh_composer=None) for tt_input_tensor in ttnn.get_device_tensors(tensor) + ] + return torch.cat(device_shards_converted_to_torch, dim=self.concat_dim) + + @contextlib.contextmanager -def distribute(default: Union[ttnn.TensorToMesh, ttnn.MeshToTensor]): +def distribute(default: Union[TensorToMesh, MeshToTensor]): """ Context manager to temporarily modify the behavior of ttnn.from_torch and ttnn.to_torch to use the specified mesh_mapper or mesh_composer for tensor distribution and composition to/from MeshDevice. 
@@ -233,9 +449,9 @@ def distribute(default: Union[ttnn.TensorToMesh, ttnn.MeshToTensor]): _original_from_torch = ttnn.from_torch try: - if isinstance(default, ttnn.TensorToMesh): + if isinstance(default, TensorToMesh): ttnn.from_torch = functools.partial(_original_from_torch, mesh_mapper=default) - elif isinstance(default, ttnn.MeshToTensor): + elif isinstance(default, MeshToTensor): ttnn.to_torch = functools.partial(_original_to_torch, mesh_composer=default) else: raise ValueError("Argument must be an instance of either TensorToMesh or MeshToTensor.") From 58b8d4623b2d771a964ef62565b506f7ad9ddee2 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Thu, 20 Feb 2025 17:25:24 +0000 Subject: [PATCH 40/76] remove unused import --- ttnn/cpp/ttnn/distributed/distributed_pybind.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 101ed4e1739..48d999ad57c 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -7,6 +7,7 @@ <<<<<<< HEAD #include <<<<<<< HEAD +<<<<<<< HEAD <<<<<<< HEAD #include @@ -33,6 +34,8 @@ >>>>>>> fix rebase ======= #include +======= +>>>>>>> remove unused import #include >>>>>>> fix rebase From 1d208e09e9c0e4be6423a01686bff5325ca18c5f Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Thu, 20 Feb 2025 20:31:53 +0000 Subject: [PATCH 41/76] rearrange from_torch.py, start migrating cpp classes and testing integration --- .../distributed/test_distributed_tensor.py | 30 ++++++++++++++++++- ttnn/ttnn/__init__.py | 14 ++++----- 2 files changed, 36 insertions(+), 8 deletions(-) diff --git a/tests/ttnn/distributed/test_distributed_tensor.py b/tests/ttnn/distributed/test_distributed_tensor.py index 4d3b593287e..b31f411fa27 100644 --- a/tests/ttnn/distributed/test_distributed_tensor.py +++ b/tests/ttnn/distributed/test_distributed_tensor.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
# SPDX-License-Identifier: Apache-2.0 @@ -10,6 +10,34 @@ from models.utility_functions import nearest_32 +@pytest.mark.parametrize( + "mesh_device", + [ + 32, + ], + indirect=True, +) +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +def test_direct_replicate_to_tensor_mesh(mesh_device, dtype): + torch.manual_seed(1234) + + torch_tensor = torch.randn(1, 1, 32, 256) + to_repl = ttnn.from_torch( + torch_tensor, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + device=mesh_device, + ) + + mapper = ttnn.CppReplicateTensorToMesh(mesh_device) + replicated_tensors = ttnn.from_torch(to_repl, mapper, mesh_device) + out_tensors = ttnn.get_device_tensors(replicated_tensors) + + out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensors[0]), torch_tensor, pcc=0.99) + logger.info(f"PCC value: {out_pcc}") + assert out_pass + + @pytest.mark.parametrize( "mesh_device", [ diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index 18dbbc78cce..dce198ccc88 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -96,13 +96,13 @@ def manage_config(name, value): from ttnn._ttnn.multi_device import ( MeshDevice, - # CppMeshToTensor, - # CppTensorToMesh, - # CppReplicateTensorToMesh, - # CppShardTensorToMesh, - # CppShardTensorTo2dMesh, - # CppConcatMeshToTensor, - # CppConcat2dMeshToTensor, + CppMeshToTensor, + CppTensorToMesh, + CppReplicateTensorToMesh, + CppShardTensorToMesh, + CppShardTensorTo2dMesh, + CppConcatMeshToTensor, + CppConcat2dMeshToTensor, ReplicateTensor, ShardTensor, ShardTensor2d, From b910b6d6a783e4ef7777b5dc958e71b99d2f2c6d Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Fri, 21 Feb 2025 23:30:18 +0000 Subject: [PATCH 42/76] interim work for supporting mappers --- .../distributed/test_distributed_tensor.py | 390 +++++++++--------- .../ttnn/distributed/distributed_pybind.cpp | 68 ++- ttnn/ttnn/__init__.py | 9 +- ttnn/ttnn/distributed/__init__.py | 8 +- ttnn/ttnn/distributed/distributed.py | 218 +++++----- ttnn/ttnn/operations/core.py | 9 +- 6 files changed, 378 insertions(+), 324 deletions(-) diff --git a/tests/ttnn/distributed/test_distributed_tensor.py b/tests/ttnn/distributed/test_distributed_tensor.py index b31f411fa27..379d2eecfda 100644 --- a/tests/ttnn/distributed/test_distributed_tensor.py +++ b/tests/ttnn/distributed/test_distributed_tensor.py @@ -1,234 +1,234 @@ -# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. - -# SPDX-License-Identifier: Apache-2.0 - -import torch -import pytest -import ttnn -from loguru import logger -from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc -from models.utility_functions import nearest_32 - - -@pytest.mark.parametrize( - "mesh_device", - [ - 32, - ], - indirect=True, -) -@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) -def test_direct_replicate_to_tensor_mesh(mesh_device, dtype): - torch.manual_seed(1234) - - torch_tensor = torch.randn(1, 1, 32, 256) - to_repl = ttnn.from_torch( - torch_tensor, - dtype=dtype, - layout=ttnn.TILE_LAYOUT, - device=mesh_device, - ) +# # SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+ +# # SPDX-License-Identifier: Apache-2.0 + +# import torch +# import pytest +# import ttnn +# from loguru import logger +# from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc +# from models.utility_functions import nearest_32 + + +# @pytest.mark.parametrize( +# "mesh_device", +# [ +# 32, +# ], +# indirect=True, +# ) +# @pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +# def test_direct_replicate_to_tensor_mesh(mesh_device, dtype): +# torch.manual_seed(1234) + +# mapper = ttnn.ReplicateTensorToMesh(mesh_device) + +# torch_tensor = torch.randn(1, 1, 32, 256) +# replicated_tensors = ttnn.from_torch( +# torch_tensor, +# dtype=dtype, +# layout=ttnn.TILE_LAYOUT, +# mesh_mapper = mapper, +# device=mesh_device, +# ) - mapper = ttnn.CppReplicateTensorToMesh(mesh_device) - replicated_tensors = ttnn.from_torch(to_repl, mapper, mesh_device) - out_tensors = ttnn.get_device_tensors(replicated_tensors) +# out_tensors = ttnn.get_device_tensors(replicated_tensors) - out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensors[0]), torch_tensor, pcc=0.99) - logger.info(f"PCC value: {out_pcc}") - assert out_pass +# out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensors[0]), torch_tensor, pcc=0.99) +# logger.info(f"PCC value: {out_pcc}") +# assert out_pass +# @pytest.mark.parametrize( +# "mesh_device", +# [ +# 32, +# ], +# indirect=True, +# ) +# @pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +# def test_replicate_to_tensor_mesh(mesh_device, dtype): +# torch.manual_seed(1234) -@pytest.mark.parametrize( - "mesh_device", - [ - 32, - ], - indirect=True, -) -@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) -def test_replicate_to_tensor_mesh(mesh_device, dtype): - torch.manual_seed(1234) +# torch_tensor = torch.randn(1, 1, 32, 256) +# to_repl = ttnn.from_torch( +# torch_tensor, +# dtype=dtype, +# layout=ttnn.TILE_LAYOUT, +# device=mesh_device, +# ) - torch_tensor = torch.randn(1, 1, 32, 256) - to_repl = ttnn.from_torch( - torch_tensor, - dtype=dtype, - layout=ttnn.TILE_LAYOUT, - device=mesh_device, - ) +# mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) +# replicated_tensors = ttnn.distribute_tensor(to_repl, mapper, mesh_device) +# out_tensors = ttnn.get_device_tensors(replicated_tensors) - mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) - replicated_tensors = ttnn.distribute_tensor(to_repl, mapper, mesh_device) - out_tensors = ttnn.get_device_tensors(replicated_tensors) +# out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensors[0]), torch_tensor, pcc=0.99) +# logger.info(f"PCC value: {out_pcc}") +# assert out_pass - out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensors[0]), torch_tensor, pcc=0.99) - logger.info(f"PCC value: {out_pcc}") - assert out_pass +# @pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +# def test_shard_to_tensor_mesh(mesh_device, dtype): +# torch.manual_seed(1234) + +# torch_tensor = torch.randn(1, 1, 32, 256) +# to_shard = ttnn.from_torch( +# torch_tensor, +# dtype=dtype, +# layout=ttnn.TILE_LAYOUT, +# device=mesh_device, +# ) + +# mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3) + +# shards = ttnn.get_device_tensors(ttnn.distribute_tensor(to_shard, mapper, mesh_device)) -@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) -def 
test_shard_to_tensor_mesh(mesh_device, dtype): - torch.manual_seed(1234) - - torch_tensor = torch.randn(1, 1, 32, 256) - to_shard = ttnn.from_torch( - torch_tensor, - dtype=dtype, - layout=ttnn.TILE_LAYOUT, - device=mesh_device, - ) - - mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3) - - shards = ttnn.get_device_tensors(ttnn.distribute_tensor(to_shard, mapper, mesh_device)) +# out_tensor = ttnn.aggregate_as_tensor(shards) - out_tensor = ttnn.aggregate_as_tensor(shards) +# out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensor), torch_tensor, pcc=0.99) +# logger.info(f"PCC value: {out_pcc}") +# assert out_pass - out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensor), torch_tensor, pcc=0.99) - logger.info(f"PCC value: {out_pcc}") - assert out_pass +# @pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +# def test_concat_to_tensor(mesh_device, dtype): +# torch.manual_seed(1234) -@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) -def test_concat_to_tensor(mesh_device, dtype): - torch.manual_seed(1234) +# torch_tensor = torch.randn(1, 1, 32, 256) +# to_shard = ttnn.from_torch( +# torch_tensor, +# dtype=dtype, +# layout=ttnn.TILE_LAYOUT, +# device=mesh_device, +# ) + +# mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3) + +# composer = ttnn.concat_mesh_to_tensor_composer(dim=3) + +# out_tensor = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device), composer) - torch_tensor = torch.randn(1, 1, 32, 256) - to_shard = ttnn.from_torch( - torch_tensor, - dtype=dtype, - layout=ttnn.TILE_LAYOUT, - device=mesh_device, - ) - - mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3) - - composer = ttnn.concat_mesh_to_tensor_composer(dim=3) - - out_tensor = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device), composer) +# out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensor), torch_tensor, pcc=0.99) +# logger.info(f"PCC value: {out_pcc}") +# assert out_pass + + +# @pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +# def test_concat_slice_to_tensor(mesh_device, dtype): +# torch.manual_seed(1234) + +# torch_tensor = torch.randn(1, 1, 32, 256) +# to_shard = ttnn.from_torch( +# torch_tensor, +# dtype=dtype, +# layout=ttnn.TILE_LAYOUT, +# device=mesh_device, +# ) + +# mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3) + +# composer = ttnn.concat_mesh_to_tensor_composer(dim=3) - out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensor), torch_tensor, pcc=0.99) - logger.info(f"PCC value: {out_pcc}") - assert out_pass - - -@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) -def test_concat_slice_to_tensor(mesh_device, dtype): - torch.manual_seed(1234) - - torch_tensor = torch.randn(1, 1, 32, 256) - to_shard = ttnn.from_torch( - torch_tensor, - dtype=dtype, - layout=ttnn.TILE_LAYOUT, - device=mesh_device, - ) - - mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3) +# sharded_tensor = ttnn.distribute_tensor(to_shard, mapper, mesh_device) - composer = ttnn.concat_mesh_to_tensor_composer(dim=3) +# shards = ttnn.get_device_tensors(sharded_tensor) - sharded_tensor = ttnn.distribute_tensor(to_shard, mapper, mesh_device) +# out_tensor = ttnn.aggregate_tensor(shards, composer) - shards = ttnn.get_device_tensors(sharded_tensor) +# out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensor), 
torch_tensor, pcc=0.99) +# logger.info(f"PCC value: {out_pcc}") +# assert out_pass - out_tensor = ttnn.aggregate_tensor(shards, composer) - out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensor), torch_tensor, pcc=0.99) - logger.info(f"PCC value: {out_pcc}") - assert out_pass +# @pytest.mark.parametrize( +# "mesh_shape, mesh_device", [pytest.param((8, 4), (8, 4), id="8x4_grid")], indirect=["mesh_device"] +# ) +# @pytest.mark.parametrize( +# "M, K, N", +# [pytest.param(32, 64, 128), pytest.param(32, 128, 64)], +# ) +# @pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +# def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): +# torch.manual_seed(1234) +# torch_tensor = torch.randn(1, 1, M, K) +# core_grid = ttnn.CoreGrid(y=1, x=8) -@pytest.mark.parametrize( - "mesh_shape, mesh_device", [pytest.param((8, 4), (8, 4), id="8x4_grid")], indirect=["mesh_device"] -) -@pytest.mark.parametrize( - "M, K, N", - [pytest.param(32, 64, 128), pytest.param(32, 128, 64)], -) -@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) -def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): - torch.manual_seed(1234) +# # If K < N it's FF1-like test case, else FF2-like test case +# shard_dim = (0, 3) if K < N else (3, 0) - torch_tensor = torch.randn(1, 1, M, K) - core_grid = ttnn.CoreGrid(y=1, x=8) +# K = K // mesh_shape[1] if K < N else K // mesh_shape[0] +# N = N // mesh_shape[0] if K < N else N // mesh_shape[1] - # If K < N it's FF1-like test case, else FF2-like test case - shard_dim = (0, 3) if K < N else (3, 0) +# sharded_mem_config = ttnn.create_sharded_memory_config( +# shape=(M // core_grid.y, K // core_grid.x), +# core_grid=core_grid, +# strategy=ttnn.ShardStrategy.WIDTH, +# orientation=ttnn.ShardOrientation.ROW_MAJOR, +# use_height_and_width_as_shard_shape=True, +# ) - K = K // mesh_shape[1] if K < N else K // mesh_shape[0] - N = N // mesh_shape[0] if K < N else N // mesh_shape[1] +# to_shard = ttnn.from_torch( +# torch_tensor, +# dtype=dtype, +# layout=ttnn.TILE_LAYOUT, +# memory_config=sharded_mem_config if M == 32 else ttnn.DRAM_MEMORY_CONFIG, +# device=mesh_device, +# ) - sharded_mem_config = ttnn.create_sharded_memory_config( - shape=(M // core_grid.y, K // core_grid.x), - core_grid=core_grid, - strategy=ttnn.ShardStrategy.WIDTH, - orientation=ttnn.ShardOrientation.ROW_MAJOR, - use_height_and_width_as_shard_shape=True, - ) +# mapper = ttnn.shard_tensor_to_2d_mesh_mapper(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) - to_shard = ttnn.from_torch( - torch_tensor, - dtype=dtype, - layout=ttnn.TILE_LAYOUT, - memory_config=sharded_mem_config if M == 32 else ttnn.DRAM_MEMORY_CONFIG, - device=mesh_device, - ) +# shards = ttnn.get_device_tensors(ttnn.distribute_tensor(to_shard, mapper, mesh_device)) - mapper = ttnn.shard_tensor_to_2d_mesh_mapper(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) +# ttnn.aggregate_as_tensor(shards) - shards = ttnn.get_device_tensors(ttnn.distribute_tensor(to_shard, mapper, mesh_device)) +# out_pass, out_pcc = comp_pcc(ttnn.to_torch(shards), torch_tensor, pcc=0.99) +# logger.info(f"PCC value: {out_pcc}") +# assert out_pass - ttnn.aggregate_as_tensor(shards) - out_pass, out_pcc = comp_pcc(ttnn.to_torch(shards), torch_tensor, pcc=0.99) - logger.info(f"PCC value: {out_pcc}") - assert out_pass +# @pytest.mark.parametrize( +# "mesh_shape, mesh_device", [pytest.param((8, 4), (8, 4), id="8x4_grid")], indirect=["mesh_device"] +# 
) +# @pytest.mark.parametrize( +# "M, K, N", +# [pytest.param(32, 128, 64), pytest.param(32, 128, 64)], +# ) +# @pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +# def test_concat2d_to_tensor(M, K, N, dtype, mesh_shape, mesh_device): +# torch.manual_seed(1234) +# torch_tensor = torch.randn(1, 1, M, K) +# core_grid = ttnn.CoreGrid(y=1, x=8) -@pytest.mark.parametrize( - "mesh_shape, mesh_device", [pytest.param((8, 4), (8, 4), id="8x4_grid")], indirect=["mesh_device"] -) -@pytest.mark.parametrize( - "M, K, N", - [pytest.param(32, 128, 64), pytest.param(32, 128, 64)], -) -@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) -def test_concat2d_to_tensor(M, K, N, dtype, mesh_shape, mesh_device): - torch.manual_seed(1234) +# # If K < N it's FF1-like test case, else FF2-like test case +# shard_dim = (0, 3) if K < N else (3, 0) +# concat_dim = (3, 1) if K < N else (1, 3) - torch_tensor = torch.randn(1, 1, M, K) - core_grid = ttnn.CoreGrid(y=1, x=8) +# K = K // mesh_shape[1] if K < N else K // mesh_shape[0] +# N = N // mesh_shape[0] if K < N else N // mesh_shape[1] - # If K < N it's FF1-like test case, else FF2-like test case - shard_dim = (0, 3) if K < N else (3, 0) - concat_dim = (3, 1) if K < N else (1, 3) +# sharded_mem_config = ttnn.create_sharded_memory_config( +# shape=(M // core_grid.y, K // core_grid.x), +# core_grid=core_grid, +# strategy=ttnn.ShardStrategy.WIDTH, +# orientation=ttnn.ShardOrientation.ROW_MAJOR, +# use_height_and_width_as_shard_shape=True, +# ) - K = K // mesh_shape[1] if K < N else K // mesh_shape[0] - N = N // mesh_shape[0] if K < N else N // mesh_shape[1] +# to_shard = ttnn.from_torch( +# torch_tensor, +# dtype=dtype, +# layout=ttnn.TILE_LAYOUT, +# memory_config=sharded_mem_config if M == 32 else ttnn.DRAM_MEMORY_CONFIG, +# device=mesh_device, +# ) - sharded_mem_config = ttnn.create_sharded_memory_config( - shape=(M // core_grid.y, K // core_grid.x), - core_grid=core_grid, - strategy=ttnn.ShardStrategy.WIDTH, - orientation=ttnn.ShardOrientation.ROW_MAJOR, - use_height_and_width_as_shard_shape=True, - ) +# mapper = ttnn.shard_tensor_to_2d_mesh_mapper(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) - to_shard = ttnn.from_torch( - torch_tensor, - dtype=dtype, - layout=ttnn.TILE_LAYOUT, - memory_config=sharded_mem_config if M == 32 else ttnn.DRAM_MEMORY_CONFIG, - device=mesh_device, - ) +# composer = ttnn.concat_2d_mesh_to_tensor_composer(mesh_device, dims=concat_dim, mesh_shape=mesh_shape) - mapper = ttnn.shard_tensor_to_2d_mesh_mapper(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) - - composer = ttnn.concat_2d_mesh_to_tensor_composer(mesh_device, dims=concat_dim, mesh_shape=mesh_shape) - - out_tensor = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device), composer) - - out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensor), torch_tensor, pcc=0.99) - logger.info(f"PCC value: {out_pcc}") - assert out_pass +# out_tensor = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device), composer) + +# out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensor), torch_tensor, pcc=0.99) +# logger.info(f"PCC value: {out_pcc}") +# assert out_pass diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 48d999ad57c..ff90c1d14b5 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -96,6 +96,55 @@ 
namespace ttnn::distributed { namespace py = pybind11; +// namespace pybind11 { namespace detail { + +// // Helper template that implements conversion from std::unique_ptr to std::unique_ptr +// template +// struct unique_ptr_base_caster { +// // This macro defines the member "value" (of type std::unique_ptr) and a type annotation. +// PYBIND11_TYPE_CASTER(std::unique_ptr, +// py::str("unique_ptr<") + type_id() + py::str(">")); + +// // load() is not supported for unique_ptr conversion from Python. +// bool load(::pybind11::handle, bool) { +// return false; +// } + +// // cast() converts the C++ unique_ptr to a Python object. +// static ::pybind11::handle cast(const std::unique_ptr& src, +// ::pybind11::return_value_policy policy, ::pybind11::handle parent) { +// // Convert the underlying raw pointer from Derived* to Base* +// Base* base_ptr = static_cast(src.get()); +// // Use py::cast on the raw pointer. Note: this does not transfer ownership. +// // The returned handle is then released. +// return py::cast(base_ptr, policy, parent).release(); +// } +// }; + +// }} // namespace pybind11::detail +// namespace pybind11 { namespace detail { + +// template <> +// struct ::pybind11::detail::type_caster> +// : unique_ptr_base_caster {}; + +// template <> +// struct ::pybind11::detail::type_caster> +// : unique_ptr_base_caster {}; + +// template <> +// struct ::pybind11::detail::type_caster> +// : unique_ptr_base_caster {}; + +// struct ::pybind11::detail::type_caster> +// : unique_ptr_base_caster {}; + +// template <> +// struct ::pybind11::detail::type_caster> +// : unique_ptr_base_caster {}; + +// }} // namespace pybind11::detail + // Trampoline class to clear virtual method errors struct ConcreteTensorToMesh : TensorToMesh { using TensorToMesh::TensorToMesh; // Inherit constructors @@ -121,13 +170,12 @@ void py_module_types(py::module& module) { <<<<<<< HEAD <<<<<<< HEAD py::class_>(module, "CppMeshToTensor"); - py::class_>(module, "CppTensorToMesh"); + py::class_>(module, "TensorToMesh"); py::class_>( - module, "CppReplicateTensorToMesh"); - py::class_>(module, "CppShardTensorToMesh"); - py::class_>( - module, "CppShardTensorTo2dMesh"); + module, "ReplicateTensorToMesh"); + py::class_>(module, "ShardTensorToMesh"); + py::class_>(module, "ShardTensorTo2dMesh"); py::class_>(module, "CppConcatMeshToTensor"); py::class_>( module, "CppConcat2dMeshToTensor"); @@ -671,19 +719,23 @@ void py_module(py::module& module) { >>>>>>> one type error left ======= auto py_tensor_to_mesh = static_cast>>( +<<<<<<< HEAD <<<<<<< HEAD module.attr("TensorToMesh")); >>>>>>> move class definitions from from distributed_tensor.cpp to.hpp so they can be exposed to the pybind.cpp; add dummy void methods in .cpp to satisfy linker; add new constructors and factory methods to fix type errors ======= module.attr("CppTensorToMesh")); >>>>>>> add back distributed.py for now, clean up class overloads +======= + module.attr("TensorToMesh")); +>>>>>>> interim work for supporting mappers py_tensor_to_mesh .def(py::init([]() -> std::unique_ptr { return std::make_unique(); })) .def("map", &TensorToMesh::map) .def("config", &TensorToMesh::config); auto py_replicate_tensor_to_mesh = static_cast>>( - module.attr("CppReplicateTensorToMesh")); + module.attr("ReplicateTensorToMesh")); py_replicate_tensor_to_mesh .def( py::init([](MeshDevice& mesh_device) -> std::unique_ptr { @@ -696,7 +748,7 @@ void py_module(py::module& module) { py::arg("tensor")) .def("config", &ReplicateTensorToMesh::config); auto py_shard_tensor_to_mesh = 
static_cast>>( - module.attr("CppShardTensorToMesh")); + module.attr("ShardTensorToMesh")); py_shard_tensor_to_mesh .def( py::init([](MeshDevice& mesh_device, int dim) -> std::unique_ptr { @@ -711,7 +763,7 @@ void py_module(py::module& module) { .def("config", &ShardTensorToMesh::config); auto py_shard_tensor_to_2d_mesh = static_cast>>( - module.attr("CppShardTensorTo2dMesh")); + module.attr("ShardTensorTo2dMesh")); py_shard_tensor_to_2d_mesh .def( py::init( diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index dce198ccc88..fba567adc09 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -94,13 +94,14 @@ def manage_config(name, value): logger.debug(f"Restored ttnn.CONFIG.{name} to {original_value}") +# apparently the names need to match the types exactly for pybind function arguments, I think a pure python alias would face the same issue from ttnn._ttnn.multi_device import ( MeshDevice, CppMeshToTensor, - CppTensorToMesh, - CppReplicateTensorToMesh, - CppShardTensorToMesh, - CppShardTensorTo2dMesh, + TensorToMesh, + ReplicateTensorToMesh, + ShardTensorToMesh, + ShardTensorTo2dMesh, CppConcatMeshToTensor, CppConcat2dMeshToTensor, ReplicateTensor, diff --git a/ttnn/ttnn/distributed/__init__.py b/ttnn/ttnn/distributed/__init__.py index 4901c6ae8cb..32d1109786c 100644 --- a/ttnn/ttnn/distributed/__init__.py +++ b/ttnn/ttnn/distributed/__init__.py @@ -6,10 +6,10 @@ from .distributed import ( MeshDevice, DispatchCoreType, - TensorToMesh, - ShardTensorToMesh, - ShardTensor2dMesh, - ReplicateTensorToMesh, + # TensorToMesh, + # ShardTensorToMesh, + # ShardTensor2dMesh, + # ReplicateTensorToMesh, MeshToTensor, ConcatMeshToTensor, ConcatMesh2dToTensor, diff --git a/ttnn/ttnn/distributed/distributed.py b/ttnn/ttnn/distributed/distributed.py index fa057bd0051..0cc33754036 100644 --- a/ttnn/ttnn/distributed/distributed.py +++ b/ttnn/ttnn/distributed/distributed.py @@ -238,105 +238,105 @@ def compose(self, tensor: ttnn.Tensor): raise NotImplementedError("Subclasses must implement this method") -class ShardTensorToMesh(TensorToMesh): - def __init__(self, mesh_device, dim): - super().__init__(mesh_device) - self.shard_dim = dim - - def map(self, tensor: "torch.Tensor") -> Dict[int, ttnn.Tensor]: - import torch - - sliced_tensors = torch.chunk(tensor, self.mesh_device.get_num_devices(), dim=self.shard_dim) - return list(sliced_tensors) - - def config(self): - return { - "strategy": "shard", - "shard_dim": f"{self.shard_dim}", - } - - -class ShardTensor2dMesh(TensorToMesh): - """ - Shard a tensor across a 2D mesh of devices. - This class implements a strategy for distributing a tensor across a 2D grid of devices, - allowing for efficient parallel processing in distributed computing environments. - """ - - def __init__(self, mesh_device: MeshDevice, mesh_shape: Tuple[int, int], dims: Tuple[Optional[int], Optional[int]]): - """ - Initialize the ShardTensor2dMesh. - Args: - mesh_device: The target device mesh for distributing the tensor. - mesh_shape: The shape of the 2D mesh as (rows, cols). - dims: The dimensions to shard along, specified as (row_dim, col_dim). - The `dims` tuple determines how the tensor is sharded across the 2D mesh: - - row_dim: The dimension to shard across mesh rows (or None for replication). - - col_dim: The dimension to shard across mesh columns (or None for replication). - Examples: - 1. dims=(2, 3) for a tensor of shape (A, B, C, D): - - Shard along dimension 2 (C) across mesh rows - - Shard along dimension 3 (D) across mesh columns - 2. 
dims=(None, 3): - - Replicate across mesh rows - - Shard along dimension 3 (D) across mesh columns - 3. dims=(None, None): - - Fully replicate the tensor across all devices - """ - super().__init__(mesh_device) - self.mesh_shape: Tuple[int, int] = mesh_shape - self.dims: Tuple[Optional[int], Optional[int]] = dims - - mesh_device_rows, mesh_device_cols = self.mesh_device.shape - if mesh_shape[0] > mesh_device_rows or mesh_shape[1] > mesh_device_cols: - raise ValueError("ShardTensor2dMesh: Device mesh shape does not match the provided mesh shape.") - - def map(self, tensor: "torch.Tensor") -> List["torch.Tensor"]: - """ - Map the input tensor to a list of sharded tensors. - Args: - tensor: The input tensor to be sharded. - Returns: - A list of sharded tensors, one for each device in the mesh. - Raises: - ValueError: If the number of sharding dimensions is not 2. - """ - import torch - - if len(self.dims) != 2: - raise ValueError("ShardTensor2dMesh only supports 2D shard dimensions") - - rows, cols = self.mesh_shape - row_dim, col_dim = self.dims - - # Shard along rows - row_tensors = ( - [tensor.clone() for _ in range(rows)] if row_dim is None else torch.chunk(tensor, rows, dim=row_dim) - ) - - # Shard along columns - if col_dim is None: - return [t.clone() for t in row_tensors for _ in range(cols)] - tensor_shards = [tt for t in row_tensors for tt in torch.chunk(t, cols, dim=col_dim)] - - if len(tensor_shards) != rows * cols: - raise ValueError( - f"ShardTensor2dMesh: Sharding failed. Number of shards should match the product of the mesh dimensions. Got {len(tensor_shards)} shards but expected {rows * cols} ({rows} rows * {cols} cols)." - ) - - return tensor_shards - - def config(self) -> Dict[str, str]: - """ - Provide the configuration of the sharding strategy. - Returns: - A dictionary containing the sharding strategy and dimensions. - """ - return { - "strategy": "shard_2d", - "mesh_shape_y": str(self.mesh_shape[0]), - "mesh_shape_x": str(self.mesh_shape[1]), - } +# class ShardTensorToMesh(TensorToMesh): +# def __init__(self, mesh_device, dim): +# super().__init__(mesh_device) +# self.shard_dim = dim + +# def map(self, tensor: "torch.Tensor") -> Dict[int, ttnn.Tensor]: +# import torch + +# sliced_tensors = torch.chunk(tensor, self.mesh_device.get_num_devices(), dim=self.shard_dim) +# return list(sliced_tensors) + +# def config(self): +# return { +# "strategy": "shard", +# "shard_dim": f"{self.shard_dim}", +# } + + +# class ShardTensor2dMesh(TensorToMesh): +# """ +# Shard a tensor across a 2D mesh of devices. +# This class implements a strategy for distributing a tensor across a 2D grid of devices, +# allowing for efficient parallel processing in distributed computing environments. +# """ + +# def __init__(self, mesh_device: MeshDevice, mesh_shape: Tuple[int, int], dims: Tuple[Optional[int], Optional[int]]): +# """ +# Initialize the ShardTensor2dMesh. +# Args: +# mesh_device: The target device mesh for distributing the tensor. +# mesh_shape: The shape of the 2D mesh as (rows, cols). +# dims: The dimensions to shard along, specified as (row_dim, col_dim). +# The `dims` tuple determines how the tensor is sharded across the 2D mesh: +# - row_dim: The dimension to shard across mesh rows (or None for replication). +# - col_dim: The dimension to shard across mesh columns (or None for replication). +# Examples: +# 1. dims=(2, 3) for a tensor of shape (A, B, C, D): +# - Shard along dimension 2 (C) across mesh rows +# - Shard along dimension 3 (D) across mesh columns +# 2. 
dims=(None, 3): +# - Replicate across mesh rows +# - Shard along dimension 3 (D) across mesh columns +# 3. dims=(None, None): +# - Fully replicate the tensor across all devices +# """ +# super().__init__(mesh_device) +# self.mesh_shape: Tuple[int, int] = mesh_shape +# self.dims: Tuple[Optional[int], Optional[int]] = dims + +# mesh_device_rows, mesh_device_cols = self.mesh_device.shape +# if mesh_shape[0] > mesh_device_rows or mesh_shape[1] > mesh_device_cols: +# raise ValueError("ShardTensor2dMesh: Device mesh shape does not match the provided mesh shape.") + +# def map(self, tensor: "torch.Tensor") -> List["torch.Tensor"]: +# """ +# Map the input tensor to a list of sharded tensors. +# Args: +# tensor: The input tensor to be sharded. +# Returns: +# A list of sharded tensors, one for each device in the mesh. +# Raises: +# ValueError: If the number of sharding dimensions is not 2. +# """ +# import torch + +# if len(self.dims) != 2: +# raise ValueError("ShardTensor2dMesh only supports 2D shard dimensions") + +# rows, cols = self.mesh_shape +# row_dim, col_dim = self.dims + +# # Shard along rows +# row_tensors = ( +# [tensor.clone() for _ in range(rows)] if row_dim is None else torch.chunk(tensor, rows, dim=row_dim) +# ) + +# # Shard along columns +# if col_dim is None: +# return [t.clone() for t in row_tensors for _ in range(cols)] +# tensor_shards = [tt for t in row_tensors for tt in torch.chunk(t, cols, dim=col_dim)] + +# if len(tensor_shards) != rows * cols: +# raise ValueError( +# f"ShardTensor2dMesh: Sharding failed. Number of shards should match the product of the mesh dimensions. Got {len(tensor_shards)} shards but expected {rows * cols} ({rows} rows * {cols} cols)." +# ) + +# return tensor_shards + +# def config(self) -> Dict[str, str]: +# """ +# Provide the configuration of the sharding strategy. +# Returns: +# A dictionary containing the sharding strategy and dimensions. 
+# """ +# return { +# "strategy": "shard_2d", +# "mesh_shape_y": str(self.mesh_shape[0]), +# "mesh_shape_x": str(self.mesh_shape[1]), +# } class ConcatMesh2dToTensor(MeshToTensor): @@ -397,18 +397,18 @@ def compose(self, tensor: ttnn.Tensor) -> "torch.Tensor": return torch.cat(row_concatenated, dim=row_dim) -class ReplicateTensorToMesh(TensorToMesh): - def __init__(self, mesh_device: MeshDevice): - super().__init__(mesh_device) +# class ReplicateTensorToMesh(TensorToMesh): +# def __init__(self, mesh_device: MeshDevice): +# super().__init__(mesh_device) - def map(self, tensor: "torch.Tensor"): - return [tensor for i in range(self.mesh_device.get_num_devices())] +# def map(self, tensor: "torch.Tensor"): +# return [tensor for i in range(self.mesh_device.get_num_devices())] - def config(self): - return { - "strategy": "replicate", - "replication_factor": str(self.mesh_device.get_num_devices()), - } +# def config(self): +# return { +# "strategy": "replicate", +# "replication_factor": str(self.mesh_device.get_num_devices()), +# } class ConcatMeshToTensor(MeshToTensor): diff --git a/ttnn/ttnn/operations/core.py b/ttnn/ttnn/operations/core.py index 1529bb328e4..7b34e969acf 100644 --- a/ttnn/ttnn/operations/core.py +++ b/ttnn/ttnn/operations/core.py @@ -156,7 +156,7 @@ def from_torch( layout: Optional[ttnn.Layout] = ttnn.ROW_MAJOR_LAYOUT, device: Optional[ttnn.Device] = None, memory_config: Optional[ttnn.MemoryConfig] = None, - mesh_mapper: Optional[Union[ttnn.TensorToMesh, ttnn.CppTensorToMesh]] = None, + mesh_mapper: Optional[ttnn.TensorToMesh] = None, cq_id: Optional[int] = ttnn.DefaultQueueId, ) -> ttnn.Tensor: """ @@ -219,9 +219,10 @@ def from_torch( if mesh_mapper: if isinstance(mesh_mapper, ttnn.MeshToTensor): shards = mesh_mapper.map(ttnn.to_torch(tensor)) + tensor = ttnn.Tensor(shards, dtype, mesh_mapper.config()) else: - shards = mesh_mapper.map(tensor) - tensor = ttnn.Tensor(shards, dtype, mesh_mapper.config()) + # currently failing - I think this path would be easier to do than calling map and then aggregate unless I add borrowedstorage to aggregate though (non-bfloats end up with that type on tensor creation) + tensor = ttnn.distribute_tensor(tensor, mesh_mapper) if tile is not None: tensor = ttnn.Tensor(tensor, dtype, {}, tile) @@ -523,7 +524,7 @@ def as_tensor( memory_config: Optional[ttnn.MemoryConfig] = None, cache_file_name: Optional[Union[str, pathlib.Path]] = None, preprocess: Optional[Callable[[ttnn.Tensor], ttnn.Tensor]] = None, - mesh_mapper: Optional[Union[ttnn.TensorToMesh, ttnn.CppTensorToMesh]] = None, + mesh_mapper: Optional[ttnn.TensorToMesh] = None, use_device_tilizer: bool = False, ) -> ttnn.Tensor: """ From b11660435a62ffe5a05aa6730fd4698bb75ca70e Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Tue, 25 Feb 2025 07:55:00 +0000 Subject: [PATCH 43/76] start trying to fix rebase errors --- .../distributed/test_distributed_tensor.py | 103 ++- .../ttnn/distributed/distributed_pybind.cpp | 584 ++---------------- 2 files changed, 122 insertions(+), 565 deletions(-) diff --git a/tests/ttnn/distributed/test_distributed_tensor.py b/tests/ttnn/distributed/test_distributed_tensor.py index 379d2eecfda..260c6705fa2 100644 --- a/tests/ttnn/distributed/test_distributed_tensor.py +++ b/tests/ttnn/distributed/test_distributed_tensor.py @@ -1,43 +1,106 @@ -# # SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +# SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import torch +import pytest +import ttnn +from loguru import logger +from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc +from models.utility_functions import nearest_32 + + +@pytest.mark.parametrize( + "mesh_device", + [ + 32, + ], + indirect=True, +) +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +def test_direct_replicate_to_tensor_mesh(mesh_device, dtype): + torch.manual_seed(1234) + + mapper = ttnn.ReplicateTensorToMesh(mesh_device) + + torch_tensor = torch.randn(1, 1, 32, 256) + replicated_tensors = ttnn.from_torch( + torch_tensor, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + mesh_mapper = mapper, + device=mesh_device, + ) + + out_tensors = ttnn.get_device_tensors(replicated_tensors) + + out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensors[0]), torch_tensor, pcc=0.99) + logger.info(f"PCC value: {out_pcc}") + assert out_pass -# # SPDX-License-Identifier: Apache-2.0 +# @pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +# def test_direct_shard_to_tensor_mesh(mesh_device, dtype): +# torch.manual_seed(1234) + +# mapper = ttnn.ShardTensorToMesh(mesh_device, dim=3) -# import torch -# import pytest -# import ttnn -# from loguru import logger -# from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc -# from models.utility_functions import nearest_32 +# torch_tensor = torch.randn(1, 1, 32, 256) +# sharded_tensor = ttnn.from_torch( +# torch_tensor, +# dtype=dtype, +# layout=ttnn.TILE_LAYOUT, +# mesh_mapper = mapper, +# device=mesh_device, +# ) +# out_pass, out_pcc = comp_pcc(ttnn.to_torch(sharded_tensor), torch_tensor, pcc=0.99) +# logger.info(f"PCC value: {out_pcc}") +# assert out_pass # @pytest.mark.parametrize( -# "mesh_device", -# [ -# 32, -# ], -# indirect=True, +# "mesh_shape, mesh_device", [pytest.param((8, 4), (8, 4), id="8x4_grid")], indirect=["mesh_device"] +# ) +# @pytest.mark.parametrize( +# "M, K, N", +# [pytest.param(32, 64, 128), pytest.param(32, 128, 64)], # ) # @pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) -# def test_direct_replicate_to_tensor_mesh(mesh_device, dtype): +# def test_direct_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): # torch.manual_seed(1234) -# mapper = ttnn.ReplicateTensorToMesh(mesh_device) +# torch_tensor = torch.randn(1, 1, M, K) +# core_grid = ttnn.CoreGrid(y=1, x=8) -# torch_tensor = torch.randn(1, 1, 32, 256) -# replicated_tensors = ttnn.from_torch( +# # If K < N it's FF1-like test case, else FF2-like test case +# shard_dim = (0, 3) if K < N else (3, 0) + +# K = K // mesh_shape[1] if K < N else K // mesh_shape[0] +# N = N // mesh_shape[0] if K < N else N // mesh_shape[1] + +# sharded_mem_config = ttnn.create_sharded_memory_config( +# shape=(M // core_grid.y, K // core_grid.x), +# core_grid=core_grid, +# strategy=ttnn.ShardStrategy.WIDTH, +# orientation=ttnn.ShardOrientation.ROW_MAJOR, +# use_height_and_width_as_shard_shape=True, +# ) + +# mapper = ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) + +# sharded_tensor = ttnn.from_torch( # torch_tensor, # dtype=dtype, # layout=ttnn.TILE_LAYOUT, +# memory_config=sharded_mem_config if M == 32 else ttnn.DRAM_MEMORY_CONFIG, # mesh_mapper = mapper, # device=mesh_device, # ) -# out_tensors = ttnn.get_device_tensors(replicated_tensors) - -# out_pass, out_pcc = 
comp_pcc(ttnn.to_torch(out_tensors[0]), torch_tensor, pcc=0.99) +# out_pass, out_pcc = comp_pcc(ttnn.to_torch(sharded_tensor), torch_tensor, pcc=0.99) # logger.info(f"PCC value: {out_pcc}") # assert out_pass + # @pytest.mark.parametrize( # "mesh_device", # [ diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index ff90c1d14b5..067a7459e2d 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -3,88 +3,20 @@ // SPDX-License-Identifier: Apache-2.0 #include "ttnn/distributed/distributed_pybind.hpp" -<<<<<<< HEAD -<<<<<<< HEAD #include -<<<<<<< HEAD -<<<<<<< HEAD -<<<<<<< HEAD -#include -======= -======= -#include -<<<<<<< HEAD ->>>>>>> one type error left -#include -#include -#include ->>>>>>> expose classes to python -======= -<<<<<<< HEAD -#include -#include -#include -======= ->>>>>>> clean up imports, fix test cases and change them to use mapper/composer functions, fix storage error by using from_device instead of cpu ->>>>>>> clean up imports, fix test cases and change them to use mapper/composer functions, fix storage error by using from_device instead of cpu - -======= #include ->>>>>>> fix rebase -======= -#include -======= ->>>>>>> remove unused import - -#include ->>>>>>> fix rebase #include -<<<<<<< HEAD #include "tt-metalium/mesh_coord.hpp" #include "distributed_tensor.hpp" #include "tt-metalium/assert.hpp" -<<<<<<< HEAD -======= -#include "distributed_tensor.hpp" -<<<<<<< HEAD -<<<<<<< HEAD ->>>>>>> expose classes to python -#include "ttnn/distributed/api.hpp" -<<<<<<< HEAD -#include "ttnn/distributed/types.hpp" -======= -#include "ttnn/distributed/distributed_tensor_config.hpp" -#include "ttnn/distributed/types.hpp" -#include "ttnn/operations/core/core.hpp" -<<<<<<< HEAD -<<<<<<< HEAD -#include "ttnn/tensor/tensor_utils.hpp" ->>>>>>> one type error left -======= ->>>>>>> clean up imports, fix test cases and change them to use mapper/composer functions, fix storage error by using from_device instead of cpu -======= -#include "distributed_tensor.cpp" -======= ->>>>>>> move class definitions from from distributed_tensor.cpp to.hpp so they can be exposed to the pybind.cpp; add dummy void methods in .cpp to satisfy linker; add new constructors and factory methods to fix type errors -======= ->>>>>>> fix test mappers, convert to cpu_tensor #include "ttnn/distributed/api.hpp" #include "ttnn/distributed/distributed_tensor_config.hpp" #include "ttnn/distributed/types.hpp" #include "ttnn/operations/core/core.hpp" -#include "ttnn/tensor/tensor_utils.hpp" ->>>>>>> one type error left #include "ttnn/tensor/tensor.hpp" +#include "ttnn/tensor/tensor_impl_wrapper.hpp" #include -======= -#include "ttnn/tensor/tensor.hpp" -#include -<<<<<<< HEAD ->>>>>>> clean up imports, fix test cases and change them to use mapper/composer functions, fix storage error by using from_device instead of cpu ->>>>>>> clean up imports, fix test cases and change them to use mapper/composer functions, fix storage error by using from_device instead of cpu -======= ->>>>>>> fix rebase // This is required for automatic conversions, as in the creation of mesh devices // https://github.com/tenstorrent/tt-metal/issues/18082 @@ -96,55 +28,6 @@ namespace ttnn::distributed { namespace py = pybind11; -// namespace pybind11 { namespace detail { - -// // Helper template that implements conversion from std::unique_ptr to std::unique_ptr -// template -// struct unique_ptr_base_caster { -// // This macro 
defines the member "value" (of type std::unique_ptr) and a type annotation. -// PYBIND11_TYPE_CASTER(std::unique_ptr, -// py::str("unique_ptr<") + type_id() + py::str(">")); - -// // load() is not supported for unique_ptr conversion from Python. -// bool load(::pybind11::handle, bool) { -// return false; -// } - -// // cast() converts the C++ unique_ptr to a Python object. -// static ::pybind11::handle cast(const std::unique_ptr& src, -// ::pybind11::return_value_policy policy, ::pybind11::handle parent) { -// // Convert the underlying raw pointer from Derived* to Base* -// Base* base_ptr = static_cast(src.get()); -// // Use py::cast on the raw pointer. Note: this does not transfer ownership. -// // The returned handle is then released. -// return py::cast(base_ptr, policy, parent).release(); -// } -// }; - -// }} // namespace pybind11::detail -// namespace pybind11 { namespace detail { - -// template <> -// struct ::pybind11::detail::type_caster> -// : unique_ptr_base_caster {}; - -// template <> -// struct ::pybind11::detail::type_caster> -// : unique_ptr_base_caster {}; - -// template <> -// struct ::pybind11::detail::type_caster> -// : unique_ptr_base_caster {}; - -// struct ::pybind11::detail::type_caster> -// : unique_ptr_base_caster {}; - -// template <> -// struct ::pybind11::detail::type_caster> -// : unique_ptr_base_caster {}; - -// }} // namespace pybind11::detail - // Trampoline class to clear virtual method errors struct ConcreteTensorToMesh : TensorToMesh { using TensorToMesh::TensorToMesh; // Inherit constructors @@ -166,67 +49,17 @@ struct ConcreteMeshToTensor : MeshToTensor { }; void py_module_types(py::module& module) { -<<<<<<< HEAD -<<<<<<< HEAD -<<<<<<< HEAD py::class_>(module, "CppMeshToTensor"); py::class_>(module, "TensorToMesh"); py::class_>( module, "ReplicateTensorToMesh"); py::class_>(module, "ShardTensorToMesh"); - py::class_>(module, "ShardTensorTo2dMesh"); + py::class_>( + module, "ShardTensor2dMesh"); py::class_>(module, "CppConcatMeshToTensor"); - py::class_>( - module, "CppConcat2dMeshToTensor"); - - py::class_(module, "ReplicateTensor"); - py::class_(module, "ShardTensor"); - py::class_(module, "ShardTensor2d"); - py::class_(module, "ShardMesh"); - py::class_(module, "AllGatherTensor"); - py::class_(module, "DistributedTensorConfig"); - - py::class_(module, "Shard2dConfig"); - py::class_(module, "Concat2dConfig"); -======= - py::class_>(module, "MeshToTensor"); - py::class_>(module, "TensorToMesh"); - py::class_(module, "TensorToMesh"); - py::class_(module, "ShardTensorToMesh"); - py::class_(module, "ShardTensorTo2dMesh"); - py::class_(module, "ConcatMeshToTensor"); - py::class_(module, "Concat2dMeshToTensor"); ->>>>>>> expose classes to python -======= - py::class_>(module, "MeshToTensor"); - py::class_>(module, "TensorToMesh"); - py::class_>( - module, "ReplicateTensorToMesh"); - py::class_>(module, "ShardTensorToMesh"); - py::class_>(module, "ShardTensor2dMesh"); - py::class_>(module, "ConcatMeshToTensor"); -<<<<<<< HEAD - py::class_>( - module, "Concat2dMeshToTensor"); ->>>>>>> one type error left -======= py::class_>( module, "ConcatMesh2dToTensor"); ->>>>>>> fix naming errors, add tests, add imports - TODO, fix weird aliasing error with meshdevice vs ttnn.multidevice.meshdevice -======= - py::class_>(module, "CppMeshToTensor"); - py::class_>(module, "CppTensorToMesh"); - - py::class_>( - module, "CppReplicateTensorToMesh"); - py::class_>(module, "CppShardTensorToMesh"); - py::class_>( - module, "CppShardTensorTo2dMesh"); - py::class_>(module, 
"CppConcatMeshToTensor"); - py::class_>( - module, "CppConcat2dMeshToTensor"); ->>>>>>> add back distributed.py for now, clean up class overloads py::class_(module, "ReplicateTensor"); py::class_(module, "ShardTensor"); @@ -246,48 +79,28 @@ void py_module_types(py::module& module) { } void py_module(py::module& module) { - // TODO: #17477 - Remove overloads that accept 'row' and 'col'. Instead, use generic ND terms. static_cast>(module.attr("MeshShape")) .def( py::init([](size_t num_rows, size_t num_cols) { return MeshShape(num_rows, num_cols); }), - "Constructor with the specified number of rows and columns.", + "Constructor with specified number of rows and columns.", py::arg("num_rows"), py::arg("num_cols")) .def( -<<<<<<< HEAD -<<<<<<< HEAD - py::init([](size_t x, size_t y, size_t z) { return MeshShape(x, y, z); }), - "Constructor with the specified 3D shape.", - py::arg("x"), - py::arg("y"), - py::arg("z")) - .def( - py::init([](const std::vector& shape) { return MeshShape(shape); }), - "Constructor with the specified ND shape.", - py::arg("shape")) -======= -======= ->>>>>>> Replace none types, expose configs, fix tuple errors py::init([](const std::tuple& dims) { return MeshShape(std::get<0>(dims), std::get<1>(dims)); }), "Constructor with specified number of rows and columns as a tuple (rows, columns).", py::arg("dims")) .def_readwrite("num_rows", &MeshShape::num_rows, "Number of rows in the mesh.") .def_readwrite("num_cols", &MeshShape::num_cols, "Number of columns in the mesh.") ->>>>>>> Replace none types, expose configs, fix tuple errors .def( "__repr__", [](const MeshShape& ms) { - std::ostringstream str; - str << ms; - return str.str(); + return ""; }) + .def("__iter__", [](const MeshShape& ms) { return py::iter(py::make_tuple(ms.num_rows, ms.num_cols)); }); + static_cast>(module.attr("MeshOffset")) .def( - "__iter__", - [](const MeshShape& ms) { return py::make_iterator(ms.view().begin(), ms.view().end()); }, - py::keep_alive<0, 1>()); - static_cast>(module.attr("MeshCoordinate")) - .def( - py::init([](size_t row, size_t col) { return MeshCoordinate(row, col); }), + py::init([](size_t row, size_t col) { return MeshOffset(row, col); }), "Constructor with specified row and column offsets.", py::arg("row"), py::arg("col")) @@ -339,17 +152,19 @@ void py_module(py::module& module) { auto py_mesh_device = static_cast>>(module.attr("MeshDevice")); py_mesh_device .def( - py::init([](const MeshShape& mesh_shape, + py::init([](const MeshShape& mesh_device_shape, size_t l1_small_size, size_t trace_region_size, size_t num_command_queues, const DispatchCoreConfig& dispatch_core_config, - const std::optional& offset, + const MeshOffset& offset, const std::vector& physical_device_ids) { return MeshDevice::create( MeshDeviceConfig{ - .mesh_shape = mesh_shape, - .offset = offset, + .mesh_shape = SimpleMeshShape(mesh_device_shape), + .offset = offset.row != 0 || offset.col != 0 + ? std::make_optional(offset.row, offset.col) + : std::nullopt, .physical_device_ids = physical_device_ids, }, l1_small_size, @@ -584,151 +399,8 @@ void py_module(py::module& module) { back to all SubDevice IDs. 
)doc"); -<<<<<<< HEAD -<<<<<<< HEAD -<<<<<<< HEAD auto py_tensor_to_mesh = static_cast>>( module.attr("CppTensorToMesh")); - py_tensor_to_mesh - .def(py::init([]() -> std::unique_ptr { return std::make_unique(); })) - .def("map", &TensorToMesh::map) - .def("config", &TensorToMesh::config); - auto py_replicate_tensor_to_mesh = - static_cast>>( - module.attr("CppReplicateTensorToMesh")); - py_replicate_tensor_to_mesh - .def( - py::init([](MeshDevice& mesh_device) -> std::unique_ptr { - return std::make_unique(ReplicateTensorToMesh(mesh_device.num_devices())); - }), - py::arg("mesh_device")) - .def( - "map", - [](const ReplicateTensorToMesh& self, const Tensor& tensor) { return self.map(tensor); }, - py::arg("tensor")) - .def("config", &ReplicateTensorToMesh::config); - auto py_shard_tensor_to_mesh = static_cast>>( - module.attr("CppShardTensorToMesh")); - py_shard_tensor_to_mesh - .def( - py::init([](MeshDevice& mesh_device, int dim) -> std::unique_ptr { - return std::make_unique(ShardTensorToMesh(mesh_device.num_devices(), dim)); - }), - py::arg("mesh_device"), - py::arg("dim")) - .def( - "map", - [](const ShardTensorToMesh& self, const Tensor& tensor) { return self.map(tensor); }, - py::arg("tensor")) - .def("config", &ShardTensorToMesh::config); - auto py_shard_tensor_to_2d_mesh = - static_cast>>( - module.attr("CppShardTensorTo2dMesh")); - py_shard_tensor_to_2d_mesh - .def( - py::init( - [](MeshDevice& mesh_device, - const std::tuple mesh_shape, - const std::tuple dims) -> std::unique_ptr { - int shape_rows = std::get<0>(mesh_shape); - int shape_cols = std::get<1>(mesh_shape); - - int config_rows = std::get<0>(dims); - int config_cols = std::get<1>(dims); - TT_FATAL( - config_rows || config_cols, - "Sharding a tensor to 2D mesh requires at least one dimension to shard"); - TT_FATAL( - shape_rows <= mesh_device.shape().num_rows && // - shape_cols <= mesh_device.shape().num_cols, - "Device mesh shape does not match the provided mesh shape."); - - return std::make_unique( - MeshShape{.num_rows = shape_rows, .num_cols = shape_cols}, - Shard2dConfig{.row_dim = std::get<0>(dims), .col_dim = std::get<1>(dims)}); - }), - py::arg("mesh_device"), - py::arg("mesh_shape"), - py::arg("dims")) - .def( - "map", - [](const ShardTensorTo2dMesh& self, const Tensor& tensor) { return self.map(tensor); }, - py::arg("tensor")) - .def("config", &ShardTensorTo2dMesh::config); - auto py_mesh_to_tensor = static_cast>>( - module.attr("CppMeshToTensor")); - py_mesh_to_tensor - .def(py::init([]() -> std::unique_ptr { return std::make_unique(); })) - .def("compose", &MeshToTensor::compose); - auto py_concat_mesh_to_tensor = static_cast>>( - module.attr("CppConcatMeshToTensor")); - py_concat_mesh_to_tensor - .def( - py::init([](MeshDevice& mesh_device, int dim) -> std::unique_ptr { - return std::make_unique(dim); - }), - py::arg("mesh_device"), - py::arg("dim")) - .def( - "compose", - [](const ConcatMeshToTensor& self, const std::vector& tensors) { return self.compose(tensors); }, - py::arg("tensors")); - - auto py_concat_2d_mesh_to_tensor = - static_cast>>( - module.attr("CppConcat2dMeshToTensor")); - py_concat_2d_mesh_to_tensor - .def( - py::init( - [](MeshDevice& mesh_device, - const std::tuple mesh_shape, - const std::tuple dims) -> std::unique_ptr { - int row_dim = std::get<0>(dims); - int col_dim = std::get<1>(dims); - TT_FATAL( - std::get<0>(mesh_shape) <= mesh_device.shape().num_rows && // - std::get<1>(mesh_shape) <= mesh_device.shape().num_cols, - "Device mesh shape does not match the provided mesh shape."); 
- - TT_FATAL( - row_dim != col_dim, - "Dimensions in 'dims' must be different; got row_dim: {}, col_dim: {}", - row_dim, - col_dim); - return std::make_unique( - mesh_device, - Concat2dConfig{ - .row_dim = row_dim, - .col_dim = col_dim, - }); - }), - py::arg("mesh_device"), - py::arg("Mesh_shape"), - py::arg("dims")) - .def( - "compose", - [](const Concat2dMeshToTensor& self, const std::vector& tensors) -> Tensor { - return self.compose(tensors); - }, - py::arg("tensors")); -======= - auto py_tensor_to_mesh = static_cast>>(module.attr("TensorToMesh")); -======= - auto py_tensor_to_mesh = - static_cast>>(module.attr("TensorToMesh")); ->>>>>>> one type error left -======= - auto py_tensor_to_mesh = static_cast>>( -<<<<<<< HEAD -<<<<<<< HEAD - module.attr("TensorToMesh")); ->>>>>>> move class definitions from from distributed_tensor.cpp to.hpp so they can be exposed to the pybind.cpp; add dummy void methods in .cpp to satisfy linker; add new constructors and factory methods to fix type errors -======= - module.attr("CppTensorToMesh")); ->>>>>>> add back distributed.py for now, clean up class overloads -======= - module.attr("TensorToMesh")); ->>>>>>> interim work for supporting mappers py_tensor_to_mesh .def(py::init([]() -> std::unique_ptr { return std::make_unique(); })) .def("map", &TensorToMesh::map) @@ -762,14 +434,14 @@ void py_module(py::module& module) { py::arg("tensor")) .def("config", &ShardTensorToMesh::config); auto py_shard_tensor_to_2d_mesh = - static_cast>>( - module.attr("ShardTensorTo2dMesh")); + static_cast>>( + module.attr("ShardTensor2dMesh")); py_shard_tensor_to_2d_mesh .def( py::init( [](MeshDevice& mesh_device, const std::tuple mesh_shape, - const std::tuple dims) -> std::unique_ptr { + const std::tuple dims) -> std::unique_ptr { int shape_rows = std::get<0>(mesh_shape); int shape_cols = std::get<1>(mesh_shape); @@ -783,7 +455,7 @@ void py_module(py::module& module) { shape_cols <= mesh_device.shape().num_cols, "Device mesh shape does not match the provided mesh shape."); - return std::make_unique( + return std::make_unique( MeshShape{.num_rows = shape_rows, .num_cols = shape_cols}, Shard2dConfig{.row_dim = std::get<0>(dims), .col_dim = std::get<1>(dims)}); }), @@ -792,9 +464,9 @@ void py_module(py::module& module) { py::arg("dims")) .def( "map", - [](const ShardTensorTo2dMesh& self, const Tensor& tensor) { return self.map(tensor); }, + [](const ShardTensor2dMesh& self, const Tensor& tensor) { return self.map(tensor); }, py::arg("tensor")) - .def("config", &ShardTensorTo2dMesh::config); + .def("config", &ShardTensor2dMesh::config); auto py_mesh_to_tensor = static_cast>>( module.attr("CppMeshToTensor")); py_mesh_to_tensor @@ -815,27 +487,14 @@ void py_module(py::module& module) { py::arg("tensors")); auto py_concat_2d_mesh_to_tensor = - static_cast>>( - module.attr("CppConcat2dMeshToTensor")); + static_cast>>( + module.attr("ConcatMesh2dToTensor")); py_concat_2d_mesh_to_tensor -<<<<<<< HEAD - .def(py::init<>(MeshDevice & mesh_device, const Concat2dConfig& config) { - return concat_2d_mesh_to_tensor_composer(mesh_device, config); - }, - py::kw_only(), - py::arg("mesh_device"), - py::arg("config")) - .def("compose",[](self, const std::vector& tensors) { - return self.compose(tensors); - }, - .py::arg("tensors")); ->>>>>>> expose classes to python -======= .def( py::init( [](MeshDevice& mesh_device, const std::tuple mesh_shape, - const std::tuple dims) -> std::unique_ptr { + const std::tuple dims) -> std::unique_ptr { int row_dim = std::get<0>(dims); int col_dim = 
std::get<1>(dims); TT_FATAL( @@ -848,7 +507,7 @@ void py_module(py::module& module) { "Dimensions in 'dims' must be different; got row_dim: {}, col_dim: {}", row_dim, col_dim); - return std::make_unique( + return std::make_unique( mesh_device, Concat2dConfig{ .row_dim = row_dim, @@ -860,11 +519,10 @@ void py_module(py::module& module) { py::arg("dims")) .def( "compose", - [](const Concat2dMeshToTensor& self, const std::vector& tensors) -> Tensor { + [](const ConcatMesh2dToTensor& self, const std::vector& tensors) -> Tensor { return self.compose(tensors); }, py::arg("tensors")); ->>>>>>> one type error left module.def( "open_mesh_device", @@ -898,15 +556,7 @@ void py_module(py::module& module) { auto py_replicate_tensor_config = static_cast>(module.attr("ShardTensor")); py_replicate_tensor_config.def(py::init<>()) .def(py::init(), py::arg("replication_factor") = 1) -<<<<<<< HEAD -<<<<<<< HEAD .def_readwrite("shard_dimension", &ReplicateTensor::replication_factor) -======= - .def_readwrite("shard_dimension", &ShardTensor::shard_dimension) ->>>>>>> add configs to pybind -======= - .def_readwrite("shard_dimension", &ReplicateTensor::replication_factor) ->>>>>>> fix test mappers, convert to cpu_tensor .def("__eq__", [](const ReplicateTensor& a, const ReplicateTensor& b) { return a.replication_factor == b.replication_factor; }); @@ -915,8 +565,6 @@ void py_module(py::module& module) { py_shard_tensor_config.def(py::init(), py::arg("shard_dimension")) .def_readwrite("shard_dimension", &ShardTensor::shard_dimension) .def("__eq__", [](const ShardTensor& a, const ShardTensor& b) { return a == b; }); -<<<<<<< HEAD -<<<<<<< HEAD auto py_shard_mesh = static_cast>(module.attr("ShardMesh")); py_shard_mesh.def(py::init<>()).def_readwrite("y", &ShardMesh::y).def_readwrite("x", &ShardMesh::x); auto py_shard_tensor2d = static_cast>(module.attr("ShardTensor2d")); @@ -936,39 +584,6 @@ void py_module(py::module& module) { py_concat2d_config.def(py::init(), py::arg("row_dim"), py::arg("col_dim")) .def_readwrite("row_dim", &Concat2dConfig::row_dim) .def_readwrite("col_dim", &Concat2dConfig::col_dim); -======= - -======= ->>>>>>> add shard2dconfig, concat2dconfig methods and map/compose constructors - auto py_shard_mesh = static_cast>(module.attr("ShardMesh")); - py_shard_mesh.def(py::init<>()).def_readwrite("y", &ShardMesh::y).def_readwrite("x", &ShardMesh::x); - auto py_shard_tensor2d = static_cast>(module.attr("ShardTensor2d")); - py_shard_tensor2d.def(py::init(), py::arg("mesh")) - .def_readonly("shard_mesh", &ShardTensor2D::shard_mesh) - .def("__eq__", [](const ShardTensor2D& a, const ShardTensor2D& b) { return a == b; }); -<<<<<<< HEAD - -<<<<<<< HEAD - auto py_allgather_config = static_cast>(module.attr("AllGatherTensor")); - .def(py::init<>()).def("__eq__", [](const AllGatherTensor& a, const AllGatherTensor& b) { return a == b; }); ->>>>>>> add configs to pybind -======= -======= ->>>>>>> add shard2dconfig, concat2dconfig methods and map/compose constructors - auto py_allgather_config = - static_cast>(module.attr("AllGatherTensor")) - .def(py::init<>()) - .def("__eq__", [](const AllGatherTensor& a, const AllGatherTensor& b) { return a == b; }); ->>>>>>> fix test mappers, convert to cpu_tensor - - auto py_shard2d_config = static_cast>(module.attr("Shard2dConfig")); - py_shard2d_config.def(py::init(), py::arg("row_dim"), py::arg("col_dim")) - .def_readwrite("row_dim", &Shard2dConfig::row_dim) - .def_readwrite("col_dim", &Shard2dConfig::col_dim); - auto py_concat2d_config = 
static_cast>(module.attr("Concat2dConfig")); - py_concat2d_config.def(py::init(), py::arg("row_dim"), py::arg("col_dim")) - .def_readwrite("row_dim", &Concat2dConfig::row_dim) - .def_readwrite("col_dim", &Concat2dConfig::col_dim); module.def( "get_distributed_tensor_config", @@ -982,31 +597,6 @@ void py_module(py::module& module) { "item": "field", } )doc"); -<<<<<<< HEAD - module.def( - "get_shard2d_config", - &get_shard2d_config, - py::arg("metadata"), - R"doc( - Returns a Shard2dConfig object given a valid metadata object of the type - { - "row_dim": "field", - "col_dim": "field", - } - )doc"); - module.def( - "get_concat2d_config", - &get_concat2d_config, - py::arg("metadata"), - R"doc( - Returns a Concat2dConfig object given a valid metadata object of the type - { - "row_dim": "field", - "col_dim": "field", - } - )doc"); -======= ->>>>>>> add configs to pybind module.def( "get_shard2d_config", &get_shard2d_config, @@ -1038,14 +628,6 @@ void py_module(py::module& module) { R"doc( Get the tensor shard corresponding to the device. -<<<<<<< HEAD -<<<<<<< HEAD -======= -<<<<<<< HEAD - ->>>>>>> clean up imports, fix test cases and change them to use mapper/composer functions, fix storage error by using from_device instead of cpu -======= ->>>>>>> fix rebase Args: tensor (Tensor): The tensor to get the shard from. device (Device): The device to get the shard for. @@ -1055,16 +637,7 @@ void py_module(py::module& module) { Tensor: The shard of the tensor corresponding to the device. )doc"); module.def("get_device_tensors", &get_device_tensors, py::arg("tensor"), py::kw_only()); -<<<<<<< HEAD -<<<<<<< HEAD -<<<<<<< HEAD -<<<<<<< HEAD // TODO: Add rdocs -======= ->>>>>>> move class definitions from from distributed_tensor.cpp to.hpp so they can be exposed to the pybind.cpp; add dummy void methods in .cpp to satisfy linker; add new constructors and factory methods to fix type errors -======= - // TODO: Add rdocs ->>>>>>> fix mesh device conflict, add aggregate/distribute and config pybinds, fix keyword error module.def( "replicate_tensor_to_mesh_mapper", [](MeshDevice& mesh_device) -> std::unique_ptr { @@ -1089,10 +662,6 @@ void py_module(py::module& module) { py::arg("mesh_shape"), py::arg("config")); module.def( -<<<<<<< HEAD -<<<<<<< HEAD -======= ->>>>>>> Replace none types, expose configs, fix tuple errors "shard_tensor_to_2d_mesh_mapper", [](MeshDevice& mesh_device, const std::tuple mesh_shape, @@ -1117,11 +686,6 @@ void py_module(py::module& module) { TensorToMesh: The created ShardTensor2dMesh mapper. 
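             Example (a minimal sketch, assembled from the commented-out test_shard2d_to_tensor_mesh
             test earlier in this series; the (8, 4) mesh fixture, the tensor shape, and the bfloat16
             dtype are illustrative assumptions, not part of the binding itself):

                 import torch
                 import ttnn

                 # Assumes an (8, 4) mesh_device fixture; shapes chosen so both shard dims divide evenly.
                 torch_tensor = torch.randn(8, 1, 32, 256)
                 to_shard = ttnn.from_torch(
                     torch_tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device
                 )
                 # dims=(0, 3): shard dim 0 across the 8 mesh rows and dim 3 across the 4 mesh columns.
                 mapper = ttnn.shard_tensor_to_2d_mesh_mapper(mesh_device, mesh_shape=(8, 4), dims=(0, 3))
                 shards = ttnn.get_device_tensors(ttnn.distribute_tensor(to_shard, mapper, mesh_device))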
)doc"); module.def( -<<<<<<< HEAD -======= ->>>>>>> move class definitions from from distributed_tensor.cpp to.hpp so they can be exposed to the pybind.cpp; add dummy void methods in .cpp to satisfy linker; add new constructors and factory methods to fix type errors -======= ->>>>>>> Replace none types, expose configs, fix tuple errors "concat_mesh_to_tensor_composer", [](int dim) -> std::unique_ptr { return concat_mesh_to_tensor_composer(dim); }, py::arg("dim")); @@ -1132,36 +696,6 @@ void py_module(py::module& module) { }, py::arg("mesh_device"), py::arg("config")); -<<<<<<< HEAD -<<<<<<< HEAD - module.def( - "concat_2d_mesh_to_tensor_composer", - [](MeshDevice& mesh_device, - const std::tuple mesh_shape, - const std::tuple dims) -> std::unique_ptr { - TT_FATAL( - std::get<0>(mesh_shape) <= mesh_device.shape().num_rows && // - std::get<1>(mesh_shape) <= mesh_device.shape().num_cols, - "Device mesh shape does not match the provided mesh shape."); - return concat_2d_mesh_to_tensor_composer( - mesh_device, Concat2dConfig{.row_dim = std::get<0>(dims), .col_dim = std::get<1>(dims)}); - }, - py::arg("mesh_device"), - py::arg("dims"), - py::arg("mesh_shape"), - R"doc( - Create a ConcatMesh2dToTensor composer with the given mesh device and dimensions. - - Args: - mesh_device (MeshDevice): The mesh device to create the composer for. - dims (Tuple[int, int]): The dimensions to create the composer for in (row, column) format. - mesh_shape (Tuple[int, int]): The shape of the 2D mesh as (num_rows, num_cols). - - Returns: - TensorToMesh: The created ConcatMesh2dToTensor composer. - )doc"); -======= ->>>>>>> fix mesh device conflict, add aggregate/distribute and config pybinds, fix keyword error module.def( "concat_2d_mesh_to_tensor_composer", [](MeshDevice& mesh_device, @@ -1193,74 +727,34 @@ void py_module(py::module& module) { [](const Tensor& tensor, const TensorToMesh& mapper, std::optional> mesh_device) -> Tensor { -<<<<<<< HEAD -<<<<<<< HEAD -<<<<<<< HEAD - return distribute_tensor(from_device(tensor), mapper, mesh_device); - }, - py::arg("tensor"), - py::arg("mapper"), - py::arg("mesh_device")); - module.def( - "aggregate_tensor", - [](const Tensor& tensor, const MeshToTensor& composer) -> Tensor { - return aggregate_tensor(from_device(tensor), composer); - }, -======= - return distribute_tensor(tensor, mapper, mesh_device); -======= - return distribute_tensor(get_cpu_tensor(tensor), mapper, mesh_device); ->>>>>>> fix test mappers, convert to cpu_tensor -======= - return distribute_tensor(from_device(tensor), mapper, mesh_device); ->>>>>>> clean up imports, fix test cases and change them to use mapper/composer functions, fix storage error by using from_device instead of cpu + Tensor cpu_tensor; + if (tensor.storage_type() == StorageType::MULTI_DEVICE_HOST) { + cpu_tensor = tt::tt_metal::tensor_impl::to_host_mesh_tensor_wrapper(tensor, true); + } else { + cpu_tensor = from_device(tensor); + } + + return distribute_tensor(cpu_tensor, mapper, mesh_device); }, py::arg("tensor"), py::arg("mapper"), py::arg("mesh_device")); module.def( "aggregate_tensor", -<<<<<<< HEAD - [](const Tensor& tensor, const MeshToTensor& composer) -> Tensor { return aggregate_tensor(tensor, composer); }, ->>>>>>> fix mesh device conflict, add aggregate/distribute and config pybinds, fix keyword error -======= [](const Tensor& tensor, const MeshToTensor& composer) -> Tensor { - return aggregate_tensor(from_device(tensor), composer); + Tensor cpu_tensor = from_device(tensor); + return aggregate_tensor(cpu_tensor, composer); 
}, ->>>>>>> fix test mappers, convert to cpu_tensor py::arg("tensor"), py::arg("composer")); module.def( "aggregate_tensor", [](const std::vector& tensors, const MeshToTensor& composer) -> Tensor { -<<<<<<< HEAD -<<<<<<< HEAD -<<<<<<< HEAD - return aggregate_tensor(from_device(aggregate_as_tensor(tensors, AllGatherTensor{})), composer); - }, - py::arg("tensor"), - py::arg("composer")); -======= - //TODO: overload this method to enable selection of a subset of shards with a config or something before passing to aggregate ->>>>>>> expose classes to python -======= -======= ->>>>>>> move class definitions from from distributed_tensor.cpp to.hpp so they can be exposed to the pybind.cpp; add dummy void methods in .cpp to satisfy linker; add new constructors and factory methods to fix type errors - // TODO: overload this method to enable selection of a subset of shards with a config or something before passing to - // aggregate ->>>>>>> one type error left -======= - return aggregate_tensor(aggregate_as_tensor(tensors, AllGatherTensor{}), composer); -======= - return aggregate_tensor(get_cpu_tensor(aggregate_as_tensor(tensors, AllGatherTensor{})), composer); ->>>>>>> fix test mappers, convert to cpu_tensor -======= - return aggregate_tensor(from_device(aggregate_as_tensor(tensors, AllGatherTensor{})), composer); ->>>>>>> clean up imports, fix test cases and change them to use mapper/composer functions, fix storage error by using from_device instead of cpu + Tensor cpu_tensor = from_device(aggregate_as_tensor(tensors, AllGatherTensor{})); + return aggregate_tensor(cpu_tensor, composer); }, py::arg("tensor"), py::arg("composer")); ->>>>>>> fix mesh device conflict, add aggregate/distribute and config pybinds, fix keyword error module.def( "aggregate_as_tensor", [](const std::vector& tensors) -> Tensor { return aggregate_as_tensor(tensors, AllGatherTensor{}); }, @@ -1269,4 +763,4 @@ void py_module(py::module& module) { module.def("get_t3k_physical_device_ids_ring", &get_t3k_physical_device_ids_ring); } -} // namespace ttnn::distributed +} // namespace ttnn::distributed \ No newline at end of file From 28cdd3be08ec08da0caef0cfb04d91d50cc92a25 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Tue, 25 Feb 2025 16:04:50 +0000 Subject: [PATCH 44/76] fix rebase errors --- .../ttnn/distributed/distributed_pybind.cpp | 124 +++++++++--------- .../ttnn/distributed/distributed_tensor.cpp | 48 ------- .../ttnn/distributed/distributed_tensor.hpp | 44 +++---- 3 files changed, 83 insertions(+), 133 deletions(-) diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 067a7459e2d..53ad59cdb8d 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -1,22 +1,19 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
-// -// SPDX-License-Identifier: Apache-2.0 - #include "ttnn/distributed/distributed_pybind.hpp" +#include #include - +#include +#include #include -#include #include "tt-metalium/mesh_coord.hpp" -#include "distributed_tensor.hpp" #include "tt-metalium/assert.hpp" +#include "distributed_tensor.hpp" #include "ttnn/distributed/api.hpp" #include "ttnn/distributed/distributed_tensor_config.hpp" #include "ttnn/distributed/types.hpp" #include "ttnn/operations/core/core.hpp" -#include "ttnn/tensor/tensor.hpp" -#include "ttnn/tensor/tensor_impl_wrapper.hpp" +#include "ttnn/tensor/tensor_utils.hpp" #include +#include "ttnn/tensor/tensor_impl_wrapper.hpp" // This is required for automatic conversions, as in the creation of mesh devices // https://github.com/tenstorrent/tt-metal/issues/18082 @@ -55,11 +52,10 @@ void py_module_types(py::module& module) { py::class_>( module, "ReplicateTensorToMesh"); py::class_>(module, "ShardTensorToMesh"); - py::class_>( - module, "ShardTensor2dMesh"); + py::class_>(module, "ShardTensorTo2dMesh"); py::class_>(module, "CppConcatMeshToTensor"); - py::class_>( - module, "ConcatMesh2dToTensor"); + py::class_>( + module, "Concat2dMeshToTensor"); py::class_(module, "ReplicateTensor"); py::class_(module, "ShardTensor"); @@ -79,28 +75,37 @@ void py_module_types(py::module& module) { } void py_module(py::module& module) { + // TODO: #17477 - Remove overloads that accept 'row' and 'col'. Instead, use generic ND terms. static_cast>(module.attr("MeshShape")) .def( py::init([](size_t num_rows, size_t num_cols) { return MeshShape(num_rows, num_cols); }), - "Constructor with specified number of rows and columns.", + "Constructor with the specified number of rows and columns.", py::arg("num_rows"), py::arg("num_cols")) .def( - py::init([](const std::tuple& dims) { return MeshShape(std::get<0>(dims), std::get<1>(dims)); }), - "Constructor with specified number of rows and columns as a tuple (rows, columns).", - py::arg("dims")) - .def_readwrite("num_rows", &MeshShape::num_rows, "Number of rows in the mesh.") - .def_readwrite("num_cols", &MeshShape::num_cols, "Number of columns in the mesh.") + py::init([](size_t x, size_t y, size_t z) { return MeshShape(x, y, z); }), + "Constructor with the specified 3D shape.", + py::arg("x"), + py::arg("y"), + py::arg("z")) + .def( + py::init([](const std::vector& shape) { return MeshShape(shape); }), + "Constructor with the specified ND shape.", + py::arg("shape")) .def( "__repr__", [](const MeshShape& ms) { - return ""; + std::ostringstream str; + str << ms; + return str.str(); }) - .def("__iter__", [](const MeshShape& ms) { return py::iter(py::make_tuple(ms.num_rows, ms.num_cols)); }); - static_cast>(module.attr("MeshOffset")) .def( - py::init([](size_t row, size_t col) { return MeshOffset(row, col); }), + "__iter__", + [](const MeshShape& ms) { return py::make_iterator(ms.view().begin(), ms.view().end()); }, + py::keep_alive<0, 1>()); + static_cast>(module.attr("MeshCoordinate")) + .def( + py::init([](size_t row, size_t col) { return MeshCoordinate(row, col); }), "Constructor with specified row and column offsets.", py::arg("row"), py::arg("col")) @@ -152,19 +157,17 @@ void py_module(py::module& module) { auto py_mesh_device = static_cast>>(module.attr("MeshDevice")); py_mesh_device .def( - py::init([](const MeshShape& mesh_device_shape, + py::init([](const MeshShape& mesh_shape, size_t l1_small_size, size_t trace_region_size, size_t num_command_queues, const DispatchCoreConfig& dispatch_core_config, - const MeshOffset& offset, + const 
std::optional& offset, const std::vector& physical_device_ids) { return MeshDevice::create( MeshDeviceConfig{ - .mesh_shape = SimpleMeshShape(mesh_device_shape), - .offset = offset.row != 0 || offset.col != 0 - ? std::make_optional(offset.row, offset.col) - : std::nullopt, + .mesh_shape = mesh_shape, + .offset = offset, .physical_device_ids = physical_device_ids, }, l1_small_size, @@ -434,16 +437,16 @@ void py_module(py::module& module) { py::arg("tensor")) .def("config", &ShardTensorToMesh::config); auto py_shard_tensor_to_2d_mesh = - static_cast>>( - module.attr("ShardTensor2dMesh")); + static_cast>>( + module.attr("ShardTensorTo2dMesh")); py_shard_tensor_to_2d_mesh .def( py::init( [](MeshDevice& mesh_device, const std::tuple mesh_shape, - const std::tuple dims) -> std::unique_ptr { - int shape_rows = std::get<0>(mesh_shape); - int shape_cols = std::get<1>(mesh_shape); + const std::tuple dims) -> std::unique_ptr { + int mesh_rows = std::get<0>(mesh_shape); + int mesh_cols = std::get<1>(mesh_shape); int config_rows = std::get<0>(dims); int config_cols = std::get<1>(dims); @@ -451,12 +454,13 @@ void py_module(py::module& module) { config_rows || config_cols, "Sharding a tensor to 2D mesh requires at least one dimension to shard"); TT_FATAL( - shape_rows <= mesh_device.shape().num_rows && // - shape_cols <= mesh_device.shape().num_cols, + mesh_rows <= mesh_device.shape()[0] && // + mesh_cols <= mesh_device.shape()[1], "Device mesh shape does not match the provided mesh shape."); - return std::make_unique( - MeshShape{.num_rows = shape_rows, .num_cols = shape_cols}, + return std::make_unique( + mesh_rows, + mesh_cols, Shard2dConfig{.row_dim = std::get<0>(dims), .col_dim = std::get<1>(dims)}); }), py::arg("mesh_device"), @@ -464,9 +468,9 @@ void py_module(py::module& module) { py::arg("dims")) .def( "map", - [](const ShardTensor2dMesh& self, const Tensor& tensor) { return self.map(tensor); }, + [](const ShardTensorTo2dMesh& self, const Tensor& tensor) { return self.map(tensor); }, py::arg("tensor")) - .def("config", &ShardTensor2dMesh::config); + .def("config", &ShardTensorTo2dMesh::config); auto py_mesh_to_tensor = static_cast>>( module.attr("CppMeshToTensor")); py_mesh_to_tensor @@ -487,19 +491,19 @@ void py_module(py::module& module) { py::arg("tensors")); auto py_concat_2d_mesh_to_tensor = - static_cast>>( - module.attr("ConcatMesh2dToTensor")); + static_cast>>( + module.attr("Concat2dMeshToTensor")); py_concat_2d_mesh_to_tensor .def( py::init( [](MeshDevice& mesh_device, const std::tuple mesh_shape, - const std::tuple dims) -> std::unique_ptr { + const std::tuple dims) -> std::unique_ptr { int row_dim = std::get<0>(dims); int col_dim = std::get<1>(dims); TT_FATAL( - std::get<0>(mesh_shape) <= mesh_device.shape().num_rows && // - std::get<1>(mesh_shape) <= mesh_device.shape().num_cols, + std::get<0>(mesh_shape) <= mesh_device.shape()[0] && // + std::get<1>(mesh_shape) <= mesh_device.shape()[1], "Device mesh shape does not match the provided mesh shape."); TT_FATAL( @@ -507,8 +511,9 @@ void py_module(py::module& module) { "Dimensions in 'dims' must be different; got row_dim: {}, col_dim: {}", row_dim, col_dim); - return std::make_unique( - mesh_device, + return std::make_unique( + row_dim, + col_dim, Concat2dConfig{ .row_dim = row_dim, .col_dim = col_dim, @@ -519,7 +524,7 @@ void py_module(py::module& module) { py::arg("dims")) .def( "compose", - [](const ConcatMesh2dToTensor& self, const std::vector& tensors) -> Tensor { + [](const Concat2dMeshToTensor& self, const std::vector& tensors) 
-> Tensor { return self.compose(tensors); }, py::arg("tensors")); @@ -675,7 +680,7 @@ void py_module(py::module& module) { py::arg("mesh_shape"), py::arg("dims"), R"doc( - Create a ShardTensor2dMesh mapper with the given mesh device, mesh shape, and dimensions. + Create a ShardTensorTo2dMesh mapper with the given mesh device, mesh shape, and dimensions. Args: mesh_device (MeshDevice): The mesh device to create the mapper for. @@ -683,27 +688,20 @@ void py_module(py::module& module) { dims (Tuple[int, int]): The dimensions to create the mapper for in (row, column) format. Returns: - TensorToMesh: The created ShardTensor2dMesh mapper. + TensorToMesh: The created ShardTensorTo2dMesh mapper. )doc"); module.def( "concat_mesh_to_tensor_composer", [](int dim) -> std::unique_ptr { return concat_mesh_to_tensor_composer(dim); }, py::arg("dim")); - module.def( - "concat_2d_mesh_to_tensor_composer", - [](MeshDevice& mesh_device, const Concat2dConfig& config) -> std::unique_ptr { - return concat_2d_mesh_to_tensor_composer(mesh_device, config); - }, - py::arg("mesh_device"), - py::arg("config")); module.def( "concat_2d_mesh_to_tensor_composer", [](MeshDevice& mesh_device, const std::tuple mesh_shape, const std::tuple dims) -> std::unique_ptr { TT_FATAL( - std::get<0>(mesh_shape) <= mesh_device.shape().num_rows && // - std::get<1>(mesh_shape) <= mesh_device.shape().num_cols, + std::get<0>(mesh_shape) <= mesh_device.shape()[0] && // + std::get<1>(mesh_shape) <= mesh_device.shape()[1], "Device mesh shape does not match the provided mesh shape."); return concat_2d_mesh_to_tensor_composer( mesh_device, Concat2dConfig{.row_dim = std::get<0>(dims), .col_dim = std::get<1>(dims)}); @@ -712,7 +710,7 @@ void py_module(py::module& module) { py::arg("dims"), py::arg("mesh_shape"), R"doc( - Create a ConcatMesh2dToTensor composer with the given mesh device and dimensions. + Create a Concat2dMeshToTensor composer with the given mesh device and dimensions. Args: mesh_device (MeshDevice): The mesh device to create the composer for. @@ -720,7 +718,7 @@ void py_module(py::module& module) { mesh_shape (Tuple[int, int]): The shape of the 2D mesh as (num_rows, num_cols). Returns: - TensorToMesh: The created ConcatMesh2dToTensor composer. + TensorToMesh: The created Concat2dMeshToTensor composer. 
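             Example (a minimal sketch continuing the shard_tensor_to_2d_mesh_mapper example above;
             `mesh_device`, `to_shard`, and `mapper` are assumed from that sketch, and the round-trip
             dims are chosen to mirror the mapper's (row_dim, col_dim)):

                 # Concatenate shards from mesh rows along dim 0 and from mesh columns along dim 3,
                 # which is intended to reassemble the layout produced by the mapper above.
                 composer = ttnn.concat_2d_mesh_to_tensor_composer(mesh_device, dims=(0, 3), mesh_shape=(8, 4))
                 out_tensor = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device), composer)
                 torch_out = ttnn.to_torch(out_tensor)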
)doc"); module.def( "distribute_tensor", @@ -763,4 +761,4 @@ void py_module(py::module& module) { module.def("get_t3k_physical_device_ids_ring", &get_t3k_physical_device_ids_ring); } -} // namespace ttnn::distributed \ No newline at end of file +} // namespace ttnn::distributed diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp index 50449ebf7fe..da12dd7a40b 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp @@ -161,23 +161,7 @@ std::unique_ptr shard_tensor_to_2d_mesh_mapper( mesh_shape[0] <= mesh_device.shape()[0] && // mesh_shape[1] <= mesh_device.shape()[1], "Device mesh shape does not match the provided mesh shape."); -<<<<<<< HEAD -<<<<<<< HEAD -<<<<<<< HEAD -<<<<<<< HEAD return std::make_unique(mesh_shape[0], mesh_shape[1], config); -======= - return std::make_unique(mesh_shape, config); ->>>>>>> fix naming errors, add tests, add imports - TODO, fix weird aliasing error with meshdevice vs ttnn.multidevice.meshdevice -======= - return std::make_unique(mesh_shape, config); ->>>>>>> add back distributed.py for now, clean up class overloads -======= - return std::make_unique(mesh_shape, config); ->>>>>>> fix naming errors, add tests, add imports - TODO, fix weird aliasing error with meshdevice vs ttnn.multidevice.meshdevice -======= - return std::make_unique(mesh_shape, config); ->>>>>>> add back distributed.py for now, clean up class overloads } std::unique_ptr concat_mesh_to_tensor_composer(int dim) { @@ -190,24 +174,8 @@ std::unique_ptr concat_2d_mesh_to_tensor_composer(MeshDevice& mesh "Dimensions in 'dims' must be different; got row_dim: {}, col_dim: {}", config.row_dim, config.col_dim); -<<<<<<< HEAD -<<<<<<< HEAD -<<<<<<< HEAD -<<<<<<< HEAD TT_FATAL(mesh_device.shape().dims() == 2, "Mesh device is not configured as a 2D mesh: {}", mesh_device.shape()); return std::make_unique(mesh_device.shape()[0], mesh_device.shape()[1], config); -======= - return std::make_unique(mesh_device, config); ->>>>>>> fix naming errors, add tests, add imports - TODO, fix weird aliasing error with meshdevice vs ttnn.multidevice.meshdevice -======= - return std::make_unique(mesh_device, config); ->>>>>>> add back distributed.py for now, clean up class overloads -======= - return std::make_unique(mesh_device, config); ->>>>>>> fix naming errors, add tests, add imports - TODO, fix weird aliasing error with meshdevice vs ttnn.multidevice.meshdevice -======= - return std::make_unique(mesh_device, config); ->>>>>>> add back distributed.py for now, clean up class overloads } Tensor distribute_tensor( @@ -230,27 +198,11 @@ Tensor aggregate_tensor(const Tensor& tensor, const MeshToTensor& composer) { : composer.compose({tensor}); } -<<<<<<< HEAD -<<<<<<< HEAD Shard2dConfig get_shard2d_config(const std::unordered_map& metadata) { return Shard2dConfig(std::stoi(metadata.at("row_dim")), std::stoi(metadata.at("col_dim"))); } Concat2dConfig get_concat2d_config(const std::unordered_map& metadata) { -======= -static Shard2dConfig get_shard2d_config(const std::unordered_map& metadata) { - return Shard2dConfig(std::stoi(metadata.at("row_dim")), std::stoi(metadata.at("col_dim"))); -} - -static Concat2dConfig get_concat2d_config(const std::unordered_map& metadata) { ->>>>>>> add shard2dconfig, concat2dconfig methods and map/compose constructors -======= -Shard2dConfig get_shard2d_config(const std::unordered_map& metadata) { - return Shard2dConfig(std::stoi(metadata.at("row_dim")), 
std::stoi(metadata.at("col_dim"))); -} - -Concat2dConfig get_concat2d_config(const std::unordered_map& metadata) { ->>>>>>> Replace none types, expose configs, fix tuple errors return Concat2dConfig(std::stoi(metadata.at("row_dim")), std::stoi(metadata.at("col_dim"))); } diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp index 806168721e9..d7cf747492c 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp @@ -85,58 +85,58 @@ class ShardTensorToMesh : public TensorToMesh { class ShardTensorTo2dMesh : public TensorToMesh { public: - ShardTensorTo2dMesh(const MeshShape& mesh_shape, const Shard2dConfig& config) : - mesh_shape_(mesh_shape), config_(config) {} + ShardTensorTo2dMesh(size_t mesh_rows, size_t mesh_cols, const Shard2dConfig& config) : + mesh_rows_(mesh_rows), mesh_cols_(mesh_cols), config_(config) {} std::vector map(const Tensor& tensor) const override { - const auto [rows, cols] = mesh_shape_; const auto [row_dim, col_dim] = config_; std::vector row_tensors; // Shard along rows if (!row_dim.has_value()) { - row_tensors.reserve(rows); - for (int i = 0; i < rows; ++i) { + row_tensors.reserve(mesh_rows_); + for (int i = 0; i < mesh_rows_; ++i) { row_tensors.push_back(tensor); } } else { - row_tensors = experimental::xtensor::chunk(tensor, rows, *row_dim); + row_tensors = experimental::xtensor::chunk(tensor, mesh_rows_, *row_dim); } std::vector tensor_shards; - tensor_shards.reserve(rows * cols); + tensor_shards.reserve(mesh_rows_ * mesh_cols_); // Shard along columns if (!col_dim.has_value()) { for (const auto& t : row_tensors) { - for (int i = 0; i < cols; ++i) { + for (int i = 0; i < mesh_cols_; ++i) { tensor_shards.push_back(t); } } } else { for (const auto& t : row_tensors) { - auto col_chunks = experimental::xtensor::chunk(t, cols, *col_dim); + auto col_chunks = experimental::xtensor::chunk(t, mesh_cols_, *col_dim); tensor_shards.insert(tensor_shards.end(), col_chunks.begin(), col_chunks.end()); } } TT_FATAL( - static_cast(tensor_shards.size()) == rows * cols, + static_cast(tensor_shards.size()) == mesh_rows_ * mesh_cols_, "ShardTensorTo2dMesh: Sharding failed. Number of shards should match the product of the mesh " "dimensions. 
Size: {}, rows: {}, cols: {}", tensor_shards.size(), - rows, - cols); + mesh_rows_, + mesh_cols_); return tensor_shards; } tt::tt_metal::DistributedTensorConfig config() const override { - return DistributedTensorConfig{ShardTensor2D{ShardMesh{mesh_shape_.num_rows, mesh_shape_.num_cols}}}; + return DistributedTensorConfig{ShardTensor2D{ShardMesh{mesh_rows_, mesh_cols_}}}; } private: - MeshShape mesh_shape_; + size_t mesh_rows_ = 0; + size_t mesh_cols_ = 0; Shard2dConfig config_; }; @@ -154,18 +154,17 @@ class ConcatMeshToTensor : public MeshToTensor { class Concat2dMeshToTensor : public MeshToTensor { public: - Concat2dMeshToTensor(MeshDevice& mesh_device, const Concat2dConfig& config) : - mesh_shape_(mesh_device.shape()), config_(config) {} + Concat2dMeshToTensor(size_t mesh_rows, size_t mesh_cols, const Concat2dConfig& config) : + mesh_rows_(mesh_rows), mesh_cols_(mesh_cols), config_(config) {} Tensor compose(const std::vector& tensors) const override { - const auto [rows, cols] = mesh_shape_; const auto [row_dim, col_dim] = config_; std::vector row_concatenated; - row_concatenated.reserve(rows); - for (int i = 0; i < rows; ++i) { - auto row_start = tensors.begin() + i * cols; - auto row_end = row_start + cols; + row_concatenated.reserve(mesh_rows_); + for (int i = 0; i < mesh_rows_; ++i) { + auto row_start = tensors.begin() + i * mesh_cols_; + auto row_end = row_start + mesh_cols_; std::vector row_tensors(row_start, row_end); row_concatenated.push_back(experimental::xtensor::concat(row_tensors, col_dim)); } @@ -174,7 +173,8 @@ class Concat2dMeshToTensor : public MeshToTensor { } private: - MeshShape mesh_shape_; + size_t mesh_rows_ = 0; + size_t mesh_cols_ = 0; Concat2dConfig config_; }; From 058891ee93244f48a9903f47c9a494d8b119eccc Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Tue, 25 Feb 2025 21:05:28 +0000 Subject: [PATCH 45/76] fix last rebase errors, re-add borrowed support for aggregate_tensor, syntax fixes --- ttnn/cpp/ttnn/distributed/api.cpp | 24 ++ .../ttnn/distributed/distributed_pybind.cpp | 25 +- ttnn/ttnn/distributed/distributed.py | 218 +++++++++--------- ttnn/ttnn/operations/core.py | 7 +- 4 files changed, 152 insertions(+), 122 deletions(-) diff --git a/ttnn/cpp/ttnn/distributed/api.cpp b/ttnn/cpp/ttnn/distributed/api.cpp index 0f6685dc5c3..91679a3168f 100644 --- a/ttnn/cpp/ttnn/distributed/api.cpp +++ b/ttnn/cpp/ttnn/distributed/api.cpp @@ -7,8 +7,10 @@ #include #include +#include "tt-metalium/assert.hpp" #include "tt-metalium/mesh_coord.hpp" #include "ttnn/tensor/tensor.hpp" +#include "ttnn/tensor/tensor_impl.hpp" #include "ttnn/tensor/tensor_utils.hpp" #include "ttnn/distributed/distributed_tensor_config.hpp" #include @@ -93,6 +95,28 @@ Tensor aggregate_as_tensor( } auto storage = MultiDeviceHostStorage{config, std::move(host_owned_buffers), specs}; return Tensor(std::move(storage), reference_shard.get_tensor_spec()); + } else if (storage_type == StorageType::BORROWED) { + std::vector specs; + std::vector host_owned_buffers; + for (const auto& shard : tensor_shards) { + auto buffer = std::get(shard.get_storage()).buffer; + + auto visitor = tt::stl::overloaded{[&shard, &host_owned_buffers](const auto& buffer) -> OwnedBuffer { + using BufferType = std::decay_t; + using ValueType = typename BufferType::value_type; + + std::vector physical_data(buffer.begin(), buffer.end()); + + std::vector logical_data = + tensor_impl::decode_tensor_data(std::move(physical_data), shard.get_tensor_spec()); + + return owned_buffer::create(std::move(logical_data)); + }}; + + 
host_owned_buffers.push_back(std::visit(visitor, buffer)); + } + auto storage = MultiDeviceHostStorage{config, std::move(host_owned_buffers), specs}; + return Tensor(std::move(storage), reference_shard.get_tensor_spec()); } else { std::vector ordered_device_ids; std::unordered_map specs; diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 53ad59cdb8d..7ed987d1440 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -55,7 +55,7 @@ void py_module_types(py::module& module) { py::class_>(module, "ShardTensorTo2dMesh"); py::class_>(module, "CppConcatMeshToTensor"); py::class_>( - module, "Concat2dMeshToTensor"); + module, "CppConcat2dMeshToTensor"); py::class_(module, "ReplicateTensor"); py::class_(module, "ShardTensor"); @@ -403,7 +403,7 @@ void py_module(py::module& module) { )doc"); auto py_tensor_to_mesh = static_cast>>( - module.attr("CppTensorToMesh")); + module.attr("TensorToMesh")); py_tensor_to_mesh .def(py::init([]() -> std::unique_ptr { return std::make_unique(); })) .def("map", &TensorToMesh::map) @@ -492,7 +492,7 @@ void py_module(py::module& module) { auto py_concat_2d_mesh_to_tensor = static_cast>>( - module.attr("Concat2dMeshToTensor")); + module.attr("CppConcat2dMeshToTensor")); py_concat_2d_mesh_to_tensor .def( py::init( @@ -724,7 +724,7 @@ void py_module(py::module& module) { "distribute_tensor", [](const Tensor& tensor, const TensorToMesh& mapper, - std::optional> mesh_device) -> Tensor { + std::optional> mesh_device = std::nullopt) -> Tensor { Tensor cpu_tensor; if (tensor.storage_type() == StorageType::MULTI_DEVICE_HOST) { cpu_tensor = tt::tt_metal::tensor_impl::to_host_mesh_tensor_wrapper(tensor, true); @@ -736,11 +736,16 @@ void py_module(py::module& module) { }, py::arg("tensor"), py::arg("mapper"), - py::arg("mesh_device")); + py::arg("mesh_device") = py::none()); module.def( "aggregate_tensor", [](const Tensor& tensor, const MeshToTensor& composer) -> Tensor { - Tensor cpu_tensor = from_device(tensor); + Tensor cpu_tensor; + if (tensor.storage_type() == StorageType::MULTI_DEVICE_HOST) { + cpu_tensor = tt::tt_metal::tensor_impl::to_host_mesh_tensor_wrapper(tensor, true); + } else { + cpu_tensor = from_device(tensor); + } return aggregate_tensor(cpu_tensor, composer); }, py::arg("tensor"), @@ -748,7 +753,13 @@ void py_module(py::module& module) { module.def( "aggregate_tensor", [](const std::vector& tensors, const MeshToTensor& composer) -> Tensor { - Tensor cpu_tensor = from_device(aggregate_as_tensor(tensors, AllGatherTensor{})); + Tensor aggregated_tensor = from_device(aggregate_as_tensor(tensors, AllGatherTensor{})); + Tensor cpu_tensor; + if (aggregated_tensor.storage_type() == StorageType::MULTI_DEVICE_HOST) { + cpu_tensor = tt::tt_metal::tensor_impl::to_host_mesh_tensor_wrapper(aggregated_tensor, true); + } else { + cpu_tensor = from_device(aggregated_tensor); + } return aggregate_tensor(cpu_tensor, composer); }, py::arg("tensor"), diff --git a/ttnn/ttnn/distributed/distributed.py b/ttnn/ttnn/distributed/distributed.py index 0cc33754036..fa057bd0051 100644 --- a/ttnn/ttnn/distributed/distributed.py +++ b/ttnn/ttnn/distributed/distributed.py @@ -238,105 +238,105 @@ def compose(self, tensor: ttnn.Tensor): raise NotImplementedError("Subclasses must implement this method") -# class ShardTensorToMesh(TensorToMesh): -# def __init__(self, mesh_device, dim): -# super().__init__(mesh_device) -# self.shard_dim = dim - -# def map(self, 
tensor: "torch.Tensor") -> Dict[int, ttnn.Tensor]: -# import torch - -# sliced_tensors = torch.chunk(tensor, self.mesh_device.get_num_devices(), dim=self.shard_dim) -# return list(sliced_tensors) - -# def config(self): -# return { -# "strategy": "shard", -# "shard_dim": f"{self.shard_dim}", -# } - - -# class ShardTensor2dMesh(TensorToMesh): -# """ -# Shard a tensor across a 2D mesh of devices. -# This class implements a strategy for distributing a tensor across a 2D grid of devices, -# allowing for efficient parallel processing in distributed computing environments. -# """ - -# def __init__(self, mesh_device: MeshDevice, mesh_shape: Tuple[int, int], dims: Tuple[Optional[int], Optional[int]]): -# """ -# Initialize the ShardTensor2dMesh. -# Args: -# mesh_device: The target device mesh for distributing the tensor. -# mesh_shape: The shape of the 2D mesh as (rows, cols). -# dims: The dimensions to shard along, specified as (row_dim, col_dim). -# The `dims` tuple determines how the tensor is sharded across the 2D mesh: -# - row_dim: The dimension to shard across mesh rows (or None for replication). -# - col_dim: The dimension to shard across mesh columns (or None for replication). -# Examples: -# 1. dims=(2, 3) for a tensor of shape (A, B, C, D): -# - Shard along dimension 2 (C) across mesh rows -# - Shard along dimension 3 (D) across mesh columns -# 2. dims=(None, 3): -# - Replicate across mesh rows -# - Shard along dimension 3 (D) across mesh columns -# 3. dims=(None, None): -# - Fully replicate the tensor across all devices -# """ -# super().__init__(mesh_device) -# self.mesh_shape: Tuple[int, int] = mesh_shape -# self.dims: Tuple[Optional[int], Optional[int]] = dims - -# mesh_device_rows, mesh_device_cols = self.mesh_device.shape -# if mesh_shape[0] > mesh_device_rows or mesh_shape[1] > mesh_device_cols: -# raise ValueError("ShardTensor2dMesh: Device mesh shape does not match the provided mesh shape.") - -# def map(self, tensor: "torch.Tensor") -> List["torch.Tensor"]: -# """ -# Map the input tensor to a list of sharded tensors. -# Args: -# tensor: The input tensor to be sharded. -# Returns: -# A list of sharded tensors, one for each device in the mesh. -# Raises: -# ValueError: If the number of sharding dimensions is not 2. -# """ -# import torch - -# if len(self.dims) != 2: -# raise ValueError("ShardTensor2dMesh only supports 2D shard dimensions") - -# rows, cols = self.mesh_shape -# row_dim, col_dim = self.dims - -# # Shard along rows -# row_tensors = ( -# [tensor.clone() for _ in range(rows)] if row_dim is None else torch.chunk(tensor, rows, dim=row_dim) -# ) - -# # Shard along columns -# if col_dim is None: -# return [t.clone() for t in row_tensors for _ in range(cols)] -# tensor_shards = [tt for t in row_tensors for tt in torch.chunk(t, cols, dim=col_dim)] - -# if len(tensor_shards) != rows * cols: -# raise ValueError( -# f"ShardTensor2dMesh: Sharding failed. Number of shards should match the product of the mesh dimensions. Got {len(tensor_shards)} shards but expected {rows * cols} ({rows} rows * {cols} cols)." -# ) - -# return tensor_shards - -# def config(self) -> Dict[str, str]: -# """ -# Provide the configuration of the sharding strategy. -# Returns: -# A dictionary containing the sharding strategy and dimensions. 
-# """ -# return { -# "strategy": "shard_2d", -# "mesh_shape_y": str(self.mesh_shape[0]), -# "mesh_shape_x": str(self.mesh_shape[1]), -# } +class ShardTensorToMesh(TensorToMesh): + def __init__(self, mesh_device, dim): + super().__init__(mesh_device) + self.shard_dim = dim + + def map(self, tensor: "torch.Tensor") -> Dict[int, ttnn.Tensor]: + import torch + + sliced_tensors = torch.chunk(tensor, self.mesh_device.get_num_devices(), dim=self.shard_dim) + return list(sliced_tensors) + + def config(self): + return { + "strategy": "shard", + "shard_dim": f"{self.shard_dim}", + } + + +class ShardTensor2dMesh(TensorToMesh): + """ + Shard a tensor across a 2D mesh of devices. + This class implements a strategy for distributing a tensor across a 2D grid of devices, + allowing for efficient parallel processing in distributed computing environments. + """ + + def __init__(self, mesh_device: MeshDevice, mesh_shape: Tuple[int, int], dims: Tuple[Optional[int], Optional[int]]): + """ + Initialize the ShardTensor2dMesh. + Args: + mesh_device: The target device mesh for distributing the tensor. + mesh_shape: The shape of the 2D mesh as (rows, cols). + dims: The dimensions to shard along, specified as (row_dim, col_dim). + The `dims` tuple determines how the tensor is sharded across the 2D mesh: + - row_dim: The dimension to shard across mesh rows (or None for replication). + - col_dim: The dimension to shard across mesh columns (or None for replication). + Examples: + 1. dims=(2, 3) for a tensor of shape (A, B, C, D): + - Shard along dimension 2 (C) across mesh rows + - Shard along dimension 3 (D) across mesh columns + 2. dims=(None, 3): + - Replicate across mesh rows + - Shard along dimension 3 (D) across mesh columns + 3. dims=(None, None): + - Fully replicate the tensor across all devices + """ + super().__init__(mesh_device) + self.mesh_shape: Tuple[int, int] = mesh_shape + self.dims: Tuple[Optional[int], Optional[int]] = dims + + mesh_device_rows, mesh_device_cols = self.mesh_device.shape + if mesh_shape[0] > mesh_device_rows or mesh_shape[1] > mesh_device_cols: + raise ValueError("ShardTensor2dMesh: Device mesh shape does not match the provided mesh shape.") + + def map(self, tensor: "torch.Tensor") -> List["torch.Tensor"]: + """ + Map the input tensor to a list of sharded tensors. + Args: + tensor: The input tensor to be sharded. + Returns: + A list of sharded tensors, one for each device in the mesh. + Raises: + ValueError: If the number of sharding dimensions is not 2. + """ + import torch + + if len(self.dims) != 2: + raise ValueError("ShardTensor2dMesh only supports 2D shard dimensions") + + rows, cols = self.mesh_shape + row_dim, col_dim = self.dims + + # Shard along rows + row_tensors = ( + [tensor.clone() for _ in range(rows)] if row_dim is None else torch.chunk(tensor, rows, dim=row_dim) + ) + + # Shard along columns + if col_dim is None: + return [t.clone() for t in row_tensors for _ in range(cols)] + tensor_shards = [tt for t in row_tensors for tt in torch.chunk(t, cols, dim=col_dim)] + + if len(tensor_shards) != rows * cols: + raise ValueError( + f"ShardTensor2dMesh: Sharding failed. Number of shards should match the product of the mesh dimensions. Got {len(tensor_shards)} shards but expected {rows * cols} ({rows} rows * {cols} cols)." + ) + + return tensor_shards + + def config(self) -> Dict[str, str]: + """ + Provide the configuration of the sharding strategy. + Returns: + A dictionary containing the sharding strategy and dimensions. 
+ """ + return { + "strategy": "shard_2d", + "mesh_shape_y": str(self.mesh_shape[0]), + "mesh_shape_x": str(self.mesh_shape[1]), + } class ConcatMesh2dToTensor(MeshToTensor): @@ -397,18 +397,18 @@ def compose(self, tensor: ttnn.Tensor) -> "torch.Tensor": return torch.cat(row_concatenated, dim=row_dim) -# class ReplicateTensorToMesh(TensorToMesh): -# def __init__(self, mesh_device: MeshDevice): -# super().__init__(mesh_device) +class ReplicateTensorToMesh(TensorToMesh): + def __init__(self, mesh_device: MeshDevice): + super().__init__(mesh_device) -# def map(self, tensor: "torch.Tensor"): -# return [tensor for i in range(self.mesh_device.get_num_devices())] + def map(self, tensor: "torch.Tensor"): + return [tensor for i in range(self.mesh_device.get_num_devices())] -# def config(self): -# return { -# "strategy": "replicate", -# "replication_factor": str(self.mesh_device.get_num_devices()), -# } + def config(self): + return { + "strategy": "replicate", + "replication_factor": str(self.mesh_device.get_num_devices()), + } class ConcatMeshToTensor(MeshToTensor): diff --git a/ttnn/ttnn/operations/core.py b/ttnn/ttnn/operations/core.py index 7b34e969acf..cf2461863e2 100644 --- a/ttnn/ttnn/operations/core.py +++ b/ttnn/ttnn/operations/core.py @@ -217,12 +217,7 @@ def from_torch( tensor = ttnn.Tensor(tensor, dtype) if mesh_mapper: - if isinstance(mesh_mapper, ttnn.MeshToTensor): - shards = mesh_mapper.map(ttnn.to_torch(tensor)) - tensor = ttnn.Tensor(shards, dtype, mesh_mapper.config()) - else: - # currently failing - I think this path would be easier to do than calling map and then aggregate unless I add borrowedstorage to aggregate though (non-bfloats end up with that type on tensor creation) - tensor = ttnn.distribute_tensor(tensor, mesh_mapper) + tensor = ttnn.distribute_tensor(tensor, mesh_mapper, device) if tile is not None: tensor = ttnn.Tensor(tensor, dtype, {}, tile) From 5dbc31e468eb9c8be7b02f90779f9343c18db863 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Tue, 25 Feb 2025 22:00:32 +0000 Subject: [PATCH 46/76] add temporary debugging, re-add copyright header, add memoryconfig for to_device --- ttnn/cpp/ttnn/distributed/distributed_pybind.cpp | 4 ++++ ttnn/cpp/ttnn/distributed/distributed_tensor.cpp | 2 +- ttnn/cpp/ttnn/tensor/tensor.cpp | 16 +++++++++++++++- ttnn/cpp/ttnn/tensor/tensor_ops.cpp | 3 +++ 4 files changed, 23 insertions(+), 2 deletions(-) diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 7ed987d1440..e1e55317ad0 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -1,3 +1,7 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + #include "ttnn/distributed/distributed_pybind.hpp" #include #include diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp index da12dd7a40b..a6bb24f4211 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp @@ -188,7 +188,7 @@ Tensor distribute_tensor( std::vector tensors = mapper.map(tensor); Tensor output = aggregate_as_tensor(tensors, mapper.config()); if (mesh_device.has_value()) { - return output.to_device(&(mesh_device->get())); + return output.to_device(&(mesh_device->get()), output.memory_config()); } return output; } diff --git a/ttnn/cpp/ttnn/tensor/tensor.cpp b/ttnn/cpp/ttnn/tensor/tensor.cpp index 30bab3457b6..bec4602668f 100644 --- a/ttnn/cpp/ttnn/tensor/tensor.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor.cpp @@ -763,7 +763,21 @@ Tensor Tensor::to_device(IDevice* target_device, const MemoryConfig& mem_config, Tensor Tensor::to_device(distributed::MeshDevice* mesh_device, const MemoryConfig& mem_config, QueueId cq_id) const { std::vector workers_to_use = ttnn::distributed::get_mapped_devices(*this, *mesh_device); - return tensor_ops::tensor_to_device(*this, workers_to_use, mem_config, cq_id); + + // TODO: remove + std::cout << "debugprint" << std::endl; + for (auto worker : workers_to_use) { + std::cout << worker->id() << std::endl; + } + std::cout << "configs" << std::endl; + std::cout << mem_config.is_sharded() << std::endl; + std::cout << mem_config.is_dram() << std::endl; + + std::cout << "tensorinfo" << std::endl; + std::cout << this->is_allocated() << std::endl; + + // return tensor_ops::tensor_to_device(*this, workers_to_use, mem_config, cq_id); + return *this; } Tensor Tensor::to_device(const std::vector& workers, const MemoryConfig& mem_config, QueueId cq_id) const { diff --git a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp index 05e51fc4fba..99b9548d84d 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp @@ -29,6 +29,9 @@ namespace tt::tt_metal::tensor_ops { Tensor tensor_to_device( const Tensor& input_tensor, IDevice* target_device, const MemoryConfig& mem_config, QueueId cq_id) { + // TODO: remove + std::cout << "debugprint2" << std::endl; + ZoneScoped; GraphTracker::instance().track_function_start("Tensor::to_device", input_tensor, target_device, mem_config); // Tensor can be using borrowed storage. If so, when running in async mode, copy this tensor to owned storage. 
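For reference, the Python-side round trip these bindings are meant to support mirrors the tests added later in this series: build a mapper, distribute a host tensor across the mesh, then aggregate it back with a composer. The sketch below is illustrative only; it assumes an already-open mesh_device handle (as provided by the mesh_device pytest fixture in those tests) and uses only functions exposed by these bindings (shard_tensor_to_mesh_mapper, distribute_tensor, concat_mesh_to_tensor_composer, aggregate_tensor); the shapes and dtype are arbitrary example values.

    import torch
    import ttnn

    def shard_round_trip(mesh_device):
        # Start from a host tensor; shape and dtype are example values only.
        torch_tensor = torch.randn(1, 1, 32, 256)
        host_tensor = ttnn.from_torch(torch_tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT)

        # Shard along the last dimension, one shard per device in the mesh,
        # then place the shards on the mesh device.
        mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3)
        sharded = ttnn.distribute_tensor(host_tensor, mapper, mesh_device)

        # Concatenate the per-device shards back into a single host tensor.
        composer = ttnn.concat_mesh_to_tensor_composer(dim=3)
        gathered = ttnn.aggregate_tensor(sharded, composer)

        # The round trip should preserve the logical shape.
        assert ttnn.to_torch(gathered).shape == torch_tensor.shape
        return gathered

Replication and 2D sharding follow the same pattern, swapping in replicate_tensor_to_mesh_mapper or shard_tensor_to_2d_mesh_mapper and concat_2d_mesh_to_tensor_composer, as exercised by the tests in the following patches.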
From 3257aaaf73209f5ca2333ddedd26103da83b36d3 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Wed, 26 Feb 2025 14:22:26 +0000 Subject: [PATCH 47/76] fix spec error --- ttnn/cpp/ttnn/distributed/api.cpp | 1 + ttnn/cpp/ttnn/tensor/tensor.cpp | 15 +-------------- ttnn/cpp/ttnn/tensor/tensor_ops.cpp | 1 + 3 files changed, 3 insertions(+), 14 deletions(-) diff --git a/ttnn/cpp/ttnn/distributed/api.cpp b/ttnn/cpp/ttnn/distributed/api.cpp index 91679a3168f..421e9ca3540 100644 --- a/ttnn/cpp/ttnn/distributed/api.cpp +++ b/ttnn/cpp/ttnn/distributed/api.cpp @@ -100,6 +100,7 @@ Tensor aggregate_as_tensor( std::vector host_owned_buffers; for (const auto& shard : tensor_shards) { auto buffer = std::get(shard.get_storage()).buffer; + specs.push_back(shard.get_tensor_spec()); auto visitor = tt::stl::overloaded{[&shard, &host_owned_buffers](const auto& buffer) -> OwnedBuffer { using BufferType = std::decay_t; diff --git a/ttnn/cpp/ttnn/tensor/tensor.cpp b/ttnn/cpp/ttnn/tensor/tensor.cpp index bec4602668f..8fc1e6e3de7 100644 --- a/ttnn/cpp/ttnn/tensor/tensor.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor.cpp @@ -764,20 +764,7 @@ Tensor Tensor::to_device(IDevice* target_device, const MemoryConfig& mem_config, Tensor Tensor::to_device(distributed::MeshDevice* mesh_device, const MemoryConfig& mem_config, QueueId cq_id) const { std::vector workers_to_use = ttnn::distributed::get_mapped_devices(*this, *mesh_device); - // TODO: remove - std::cout << "debugprint" << std::endl; - for (auto worker : workers_to_use) { - std::cout << worker->id() << std::endl; - } - std::cout << "configs" << std::endl; - std::cout << mem_config.is_sharded() << std::endl; - std::cout << mem_config.is_dram() << std::endl; - - std::cout << "tensorinfo" << std::endl; - std::cout << this->is_allocated() << std::endl; - - // return tensor_ops::tensor_to_device(*this, workers_to_use, mem_config, cq_id); - return *this; + return tensor_ops::tensor_to_device(*this, workers_to_use, mem_config, cq_id); } Tensor Tensor::to_device(const std::vector& workers, const MemoryConfig& mem_config, QueueId cq_id) const { diff --git a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp index 99b9548d84d..65ddf493197 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp @@ -197,6 +197,7 @@ Tensor tensor_to_layout(const Tensor& input_tensor, Layout target_layout, distri host_storage != nullptr) { distributed_config = host_storage->strategy; } + // TODO: remove, check the tilize tomorrow Tensor tensor_modified_layout = Tensor(workers.size(), distributed_config); for (int worker_index = 0; worker_index < workers.size(); ++worker_index) { From 45c56e674d8d46da0f79feab4da9bf4dbcfa9ddf Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Wed, 26 Feb 2025 19:33:14 +0000 Subject: [PATCH 48/76] debugging prints for tilize, add switch back and move all classes back into cpp --- .../distributed/test_distributed_tensor.py | 27 +++- .../ttnn/distributed/distributed_pybind.cpp | 120 +++++++--------- .../ttnn/distributed/distributed_tensor.hpp | 131 ------------------ .../core/to_layout/to_layout_op.cpp | 2 + .../data_movement/tilize/tilize.cpp | 1 + ttnn/ttnn/__init__.py | 9 +- ttnn/ttnn/distributed/__init__.py | 8 +- ttnn/ttnn/operations/core.py | 22 +-- 8 files changed, 98 insertions(+), 222 deletions(-) diff --git a/tests/ttnn/distributed/test_distributed_tensor.py b/tests/ttnn/distributed/test_distributed_tensor.py index 260c6705fa2..0bd58d849dd 100644 --- a/tests/ttnn/distributed/test_distributed_tensor.py +++ 
b/tests/ttnn/distributed/test_distributed_tensor.py @@ -17,7 +17,7 @@ ], indirect=True, ) -@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.int32, ttnn.uint8, ttnn.uint32]) def test_direct_replicate_to_tensor_mesh(mesh_device, dtype): torch.manual_seed(1234) @@ -28,7 +28,7 @@ def test_direct_replicate_to_tensor_mesh(mesh_device, dtype): torch_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT, - mesh_mapper = mapper, + mesh_mapper=mapper, device=mesh_device, ) @@ -38,6 +38,29 @@ def test_direct_replicate_to_tensor_mesh(mesh_device, dtype): logger.info(f"PCC value: {out_pcc}") assert out_pass + +# # @pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +# @pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.int32, ttnn.uint8, ttnn.uint32]) +# def test_direct_replicate_to_tensor_mesh(mesh_device, dtype): +# torch.manual_seed(1234) + +# mapper = ttnn.CppReplicateTensorToMesh(mesh_device) + +# torch_tensor = torch.randn(1, 1, 32, 256) +# replicated_tensors = ttnn.from_torch( +# torch_tensor, +# dtype=dtype, +# layout=ttnn.TILE_LAYOUT, +# mesh_mapper = mapper, +# device=mesh_device, +# ) + +# out_tensors = ttnn.get_device_tensors(replicated_tensors) + +# out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensors[0]), torch_tensor, pcc=0.99) +# logger.info(f"PCC value: {out_pcc}") +# assert out_pass + # @pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) # def test_direct_shard_to_tensor_mesh(mesh_device, dtype): # torch.manual_seed(1234) diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index e1e55317ad0..35ef5eb1c99 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -49,14 +49,22 @@ struct ConcreteMeshToTensor : MeshToTensor { } }; +// unused empty implementations to satisfy pybind's desire for unique objects +class ReplicateTensorToMesh : public TensorToMesh {}; +class ShardTensorToMesh : public TensorToMesh {}; +class ShardTensorTo2dMesh : public TensorToMesh {}; +class ConcatMeshToTensor : public MeshToTensor {}; +class Concat2dMeshToTensor : public MeshToTensor {}; + void py_module_types(py::module& module) { py::class_>(module, "CppMeshToTensor"); - py::class_>(module, "TensorToMesh"); + py::class_>(module, "CppTensorToMesh"); py::class_>( - module, "ReplicateTensorToMesh"); - py::class_>(module, "ShardTensorToMesh"); - py::class_>(module, "ShardTensorTo2dMesh"); + module, "CppReplicateTensorToMesh"); + py::class_>(module, "CppShardTensorToMesh"); + py::class_>( + module, "CppShardTensorTo2dMesh"); py::class_>(module, "CppConcatMeshToTensor"); py::class_>( module, "CppConcat2dMeshToTensor"); @@ -407,128 +415,96 @@ void py_module(py::module& module) { )doc"); auto py_tensor_to_mesh = static_cast>>( - module.attr("TensorToMesh")); + module.attr("CppTensorToMesh")); py_tensor_to_mesh .def(py::init([]() -> std::unique_ptr { return std::make_unique(); })) .def("map", &TensorToMesh::map) .def("config", &TensorToMesh::config); auto py_replicate_tensor_to_mesh = - static_cast>>( - module.attr("ReplicateTensorToMesh")); + static_cast>>(module.attr("CppReplicateTensorToMesh")); py_replicate_tensor_to_mesh .def( - py::init([](MeshDevice& mesh_device) -> std::unique_ptr { - return std::make_unique(ReplicateTensorToMesh(mesh_device.num_devices())); + 
py::init([](MeshDevice& mesh_device) -> std::unique_ptr { + return replicate_tensor_to_mesh_mapper(mesh_device); }), py::arg("mesh_device")) .def( - "map", - [](const ReplicateTensorToMesh& self, const Tensor& tensor) { return self.map(tensor); }, - py::arg("tensor")) - .def("config", &ReplicateTensorToMesh::config); - auto py_shard_tensor_to_mesh = static_cast>>( - module.attr("ShardTensorToMesh")); + "map", [](const TensorToMesh& self, const Tensor& tensor) { return self.map(tensor); }, py::arg("tensor")) + .def("config", &TensorToMesh::config); + auto py_shard_tensor_to_mesh = + static_cast>>(module.attr("CppShardTensorToMesh")); py_shard_tensor_to_mesh .def( - py::init([](MeshDevice& mesh_device, int dim) -> std::unique_ptr { - return std::make_unique(ShardTensorToMesh(mesh_device.num_devices(), dim)); + py::init([](MeshDevice& mesh_device, int dim) -> std::unique_ptr { + return shard_tensor_to_mesh_mapper(mesh_device, dim); }), py::arg("mesh_device"), py::arg("dim")) .def( - "map", - [](const ShardTensorToMesh& self, const Tensor& tensor) { return self.map(tensor); }, - py::arg("tensor")) - .def("config", &ShardTensorToMesh::config); + "map", [](const TensorToMesh& self, const Tensor& tensor) { return self.map(tensor); }, py::arg("tensor")) + .def("config", &TensorToMesh::config); auto py_shard_tensor_to_2d_mesh = - static_cast>>( - module.attr("ShardTensorTo2dMesh")); + static_cast>>(module.attr("CppShardTensorTo2dMesh")); py_shard_tensor_to_2d_mesh .def( py::init( [](MeshDevice& mesh_device, const std::tuple mesh_shape, - const std::tuple dims) -> std::unique_ptr { + const std::tuple dims) -> std::unique_ptr { int mesh_rows = std::get<0>(mesh_shape); int mesh_cols = std::get<1>(mesh_shape); int config_rows = std::get<0>(dims); int config_cols = std::get<1>(dims); - TT_FATAL( - config_rows || config_cols, - "Sharding a tensor to 2D mesh requires at least one dimension to shard"); - TT_FATAL( - mesh_rows <= mesh_device.shape()[0] && // - mesh_cols <= mesh_device.shape()[1], - "Device mesh shape does not match the provided mesh shape."); - - return std::make_unique( - mesh_rows, - mesh_cols, + return shard_tensor_to_2d_mesh_mapper( + mesh_device, + MeshShape(mesh_rows, mesh_cols), Shard2dConfig{.row_dim = std::get<0>(dims), .col_dim = std::get<1>(dims)}); }), py::arg("mesh_device"), py::arg("mesh_shape"), py::arg("dims")) .def( - "map", - [](const ShardTensorTo2dMesh& self, const Tensor& tensor) { return self.map(tensor); }, - py::arg("tensor")) - .def("config", &ShardTensorTo2dMesh::config); + "map", [](const TensorToMesh& self, const Tensor& tensor) { return self.map(tensor); }, py::arg("tensor")) + .def("config", &TensorToMesh::config); auto py_mesh_to_tensor = static_cast>>( module.attr("CppMeshToTensor")); py_mesh_to_tensor .def(py::init([]() -> std::unique_ptr { return std::make_unique(); })) .def("compose", &MeshToTensor::compose); - auto py_concat_mesh_to_tensor = static_cast>>( - module.attr("CppConcatMeshToTensor")); + auto py_concat_mesh_to_tensor = + static_cast>>(module.attr("CppConcatMeshToTensor")); py_concat_mesh_to_tensor .def( - py::init([](MeshDevice& mesh_device, int dim) -> std::unique_ptr { - return std::make_unique(dim); + py::init([](MeshDevice& mesh_device, int dim) -> std::unique_ptr { + return concat_mesh_to_tensor_composer(dim); }), py::arg("mesh_device"), py::arg("dim")) .def( "compose", - [](const ConcatMeshToTensor& self, const std::vector& tensors) { return self.compose(tensors); }, + [](const MeshToTensor& self, const std::vector& tensors) { return 
self.compose(tensors); }, py::arg("tensors")); auto py_concat_2d_mesh_to_tensor = - static_cast>>( - module.attr("CppConcat2dMeshToTensor")); + static_cast>>(module.attr("CppConcat2dMeshToTensor")); py_concat_2d_mesh_to_tensor .def( - py::init( - [](MeshDevice& mesh_device, - const std::tuple mesh_shape, - const std::tuple dims) -> std::unique_ptr { - int row_dim = std::get<0>(dims); - int col_dim = std::get<1>(dims); - TT_FATAL( - std::get<0>(mesh_shape) <= mesh_device.shape()[0] && // - std::get<1>(mesh_shape) <= mesh_device.shape()[1], - "Device mesh shape does not match the provided mesh shape."); - - TT_FATAL( - row_dim != col_dim, - "Dimensions in 'dims' must be different; got row_dim: {}, col_dim: {}", - row_dim, - col_dim); - return std::make_unique( - row_dim, - col_dim, - Concat2dConfig{ - .row_dim = row_dim, - .col_dim = col_dim, - }); - }), + py::init([](MeshDevice& mesh_device, const std::tuple dims) -> std::unique_ptr { + int row_dim = std::get<0>(dims); + int col_dim = std::get<1>(dims); + return concat_2d_mesh_to_tensor_composer( + mesh_device, + Concat2dConfig{ + .row_dim = row_dim, + .col_dim = col_dim, + }); + }), py::arg("mesh_device"), - py::arg("Mesh_shape"), py::arg("dims")) .def( "compose", - [](const Concat2dMeshToTensor& self, const std::vector& tensors) -> Tensor { + [](const MeshToTensor& self, const std::vector& tensors) -> Tensor { return self.compose(tensors); }, py::arg("tensors")); diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp index d7cf747492c..edfd47a80b8 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp @@ -47,137 +47,6 @@ struct Concat2dConfig { int col_dim = -1; }; -class ReplicateTensorToMesh : public TensorToMesh { -public: - ReplicateTensorToMesh(size_t num_devices) : num_devices_(num_devices) {} - - std::vector map(const Tensor& tensor) const override { - std::vector tensors; - tensors.reserve(num_devices_); - std::fill_n(std::back_inserter(tensors), num_devices_, tensor); - return tensors; - } - - tt::tt_metal::DistributedTensorConfig config() const override { - return tt::tt_metal::DistributedTensorConfig{ReplicateTensor{num_devices_}}; - } - -private: - size_t num_devices_ = 0; -}; - -class ShardTensorToMesh : public TensorToMesh { -public: - ShardTensorToMesh(size_t num_devices, int dim) : num_devices_(num_devices), shard_dim_(dim) {} - - std::vector map(const Tensor& tensor) const override { - return experimental::xtensor::chunk(tensor, num_devices_, shard_dim_); - } - - tt::tt_metal::DistributedTensorConfig config() const override { - return tt::tt_metal::DistributedTensorConfig{ShardTensor{shard_dim_}}; - } - -private: - size_t num_devices_ = 0; - int shard_dim_ = -1; -}; - -class ShardTensorTo2dMesh : public TensorToMesh { -public: - ShardTensorTo2dMesh(size_t mesh_rows, size_t mesh_cols, const Shard2dConfig& config) : - mesh_rows_(mesh_rows), mesh_cols_(mesh_cols), config_(config) {} - - std::vector map(const Tensor& tensor) const override { - const auto [row_dim, col_dim] = config_; - - std::vector row_tensors; - - // Shard along rows - if (!row_dim.has_value()) { - row_tensors.reserve(mesh_rows_); - for (int i = 0; i < mesh_rows_; ++i) { - row_tensors.push_back(tensor); - } - } else { - row_tensors = experimental::xtensor::chunk(tensor, mesh_rows_, *row_dim); - } - - std::vector tensor_shards; - tensor_shards.reserve(mesh_rows_ * mesh_cols_); - // Shard along columns - if (!col_dim.has_value()) { - for 
(const auto& t : row_tensors) { - for (int i = 0; i < mesh_cols_; ++i) { - tensor_shards.push_back(t); - } - } - } else { - for (const auto& t : row_tensors) { - auto col_chunks = experimental::xtensor::chunk(t, mesh_cols_, *col_dim); - tensor_shards.insert(tensor_shards.end(), col_chunks.begin(), col_chunks.end()); - } - } - - TT_FATAL( - static_cast(tensor_shards.size()) == mesh_rows_ * mesh_cols_, - "ShardTensorTo2dMesh: Sharding failed. Number of shards should match the product of the mesh " - "dimensions. Size: {}, rows: {}, cols: {}", - tensor_shards.size(), - mesh_rows_, - mesh_cols_); - - return tensor_shards; - } - - tt::tt_metal::DistributedTensorConfig config() const override { - return DistributedTensorConfig{ShardTensor2D{ShardMesh{mesh_rows_, mesh_cols_}}}; - } - -private: - size_t mesh_rows_ = 0; - size_t mesh_cols_ = 0; - Shard2dConfig config_; -}; - -class ConcatMeshToTensor : public MeshToTensor { -public: - ConcatMeshToTensor(int dim) : concat_dim_(dim) {} - - Tensor compose(const std::vector& tensors) const override { - return experimental::xtensor::concat(tensors, concat_dim_); - } - -private: - int concat_dim_ = -1; -}; - -class Concat2dMeshToTensor : public MeshToTensor { -public: - Concat2dMeshToTensor(size_t mesh_rows, size_t mesh_cols, const Concat2dConfig& config) : - mesh_rows_(mesh_rows), mesh_cols_(mesh_cols), config_(config) {} - - Tensor compose(const std::vector& tensors) const override { - const auto [row_dim, col_dim] = config_; - - std::vector row_concatenated; - row_concatenated.reserve(mesh_rows_); - for (int i = 0; i < mesh_rows_; ++i) { - auto row_start = tensors.begin() + i * mesh_cols_; - auto row_end = row_start + mesh_cols_; - std::vector row_tensors(row_start, row_end); - row_concatenated.push_back(experimental::xtensor::concat(row_tensors, col_dim)); - } - - return experimental::xtensor::concat(row_concatenated, row_dim); - } - -private: - size_t mesh_rows_ = 0; - size_t mesh_cols_ = 0; - Concat2dConfig config_; -}; - // Creates a mapper that replicates a tensor across all devices. 
std::unique_ptr replicate_tensor_to_mesh_mapper(MeshDevice& mesh_device); diff --git a/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp b/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp index 6429d55226b..db13774bbb9 100644 --- a/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp +++ b/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.cpp @@ -14,6 +14,7 @@ #include #include "cpp/ttnn/operations/experimental/reshape/view.hpp" #include "ttnn/operations/core/core.hpp" +#include "ttnn/tensor/types.hpp" #include "ttnn/types.hpp" namespace ttnn { @@ -104,6 +105,7 @@ Tensor to_layout_impl( TT_ASSERT(not dtype.has_value(), "dtype cannot be specified when converting to ROW_MAJOR_LAYOUT!"); return ttnn::untilize(tensor, output_memory_config, use_multicore_untilize); } else if (layout == ttnn::TILE_LAYOUT) { + std::cout << "tilizing1" << std::endl; if (tensor.is_sharded()) { const auto tensor_tile = tensor.get_tensor_spec().tile(); uint32_t tile_height = tensor_tile.get_height(); diff --git a/ttnn/cpp/ttnn/operations/data_movement/tilize/tilize.cpp b/ttnn/cpp/ttnn/operations/data_movement/tilize/tilize.cpp index e566e554d39..8e9f0c426c7 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/tilize/tilize.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/tilize/tilize.cpp @@ -3,6 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 #include "tilize.hpp" +#include #include "device/tilize_op.hpp" #include "ttnn/common/queue_id.hpp" diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index fba567adc09..dce198ccc88 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -94,14 +94,13 @@ def manage_config(name, value): logger.debug(f"Restored ttnn.CONFIG.{name} to {original_value}") -# apparently the names need to match the types exactly for pybind function arguments, I think a pure python alias would face the same issue from ttnn._ttnn.multi_device import ( MeshDevice, CppMeshToTensor, - TensorToMesh, - ReplicateTensorToMesh, - ShardTensorToMesh, - ShardTensorTo2dMesh, + CppTensorToMesh, + CppReplicateTensorToMesh, + CppShardTensorToMesh, + CppShardTensorTo2dMesh, CppConcatMeshToTensor, CppConcat2dMeshToTensor, ReplicateTensor, diff --git a/ttnn/ttnn/distributed/__init__.py b/ttnn/ttnn/distributed/__init__.py index 32d1109786c..4901c6ae8cb 100644 --- a/ttnn/ttnn/distributed/__init__.py +++ b/ttnn/ttnn/distributed/__init__.py @@ -6,10 +6,10 @@ from .distributed import ( MeshDevice, DispatchCoreType, - # TensorToMesh, - # ShardTensorToMesh, - # ShardTensor2dMesh, - # ReplicateTensorToMesh, + TensorToMesh, + ShardTensorToMesh, + ShardTensor2dMesh, + ReplicateTensorToMesh, MeshToTensor, ConcatMeshToTensor, ConcatMesh2dToTensor, diff --git a/ttnn/ttnn/operations/core.py b/ttnn/ttnn/operations/core.py index cf2461863e2..359c0669d32 100644 --- a/ttnn/ttnn/operations/core.py +++ b/ttnn/ttnn/operations/core.py @@ -156,7 +156,7 @@ def from_torch( layout: Optional[ttnn.Layout] = ttnn.ROW_MAJOR_LAYOUT, device: Optional[ttnn.Device] = None, memory_config: Optional[ttnn.MemoryConfig] = None, - mesh_mapper: Optional[ttnn.TensorToMesh] = None, + mesh_mapper: Union[ttnn.TensorToMesh, ttnn.CppTensorToMesh] = None, cq_id: Optional[int] = ttnn.DefaultQueueId, ) -> ttnn.Tensor: """ @@ -208,7 +208,7 @@ def from_torch( if dtype == ttnn.bfloat8_b or dtype == ttnn.bfloat4_b: if layout != ttnn.TILE_LAYOUT: raise RuntimeError("ttnn.from_torch: bfloat8_b/bfloat4_b requires TILE_LAYOUT!") - # Tilize tensor + # Tilize tensor, TODO: this is incredibly non-performant tensor = ttnn.from_torch(tensor, 
layout=ttnn.TILE_LAYOUT, tile=tile, pad_value=pad_value, mesh_mapper=None) logical_shape = tensor.shape padded_shape = tensor.padded_shape @@ -217,7 +217,11 @@ def from_torch( tensor = ttnn.Tensor(tensor, dtype) if mesh_mapper: - tensor = ttnn.distribute_tensor(tensor, mesh_mapper, device) + if isinstance(mesh_mapper, ttnn.CppTensorToMesh): + tensor = ttnn.distribute_tensor(tensor, mesh_mapper, device) + else: + shards = mesh_mapper.map(tensor) + tensor = ttnn.aggregate_as_tensor(shards) if tile is not None: tensor = ttnn.Tensor(tensor, dtype, {}, tile) @@ -225,7 +229,7 @@ def from_torch( if layout is not None and not (dtype == ttnn.bfloat8_b or dtype == ttnn.bfloat4_b): if pad_value is not None: tensor = tensor.pad_to_tile(pad_value) - tensor = ttnn.to_layout(tensor, layout, device=device) + tensor = ttnn.to_layout(tensor, layout) if device is not None: if memory_config is None: @@ -519,7 +523,7 @@ def as_tensor( memory_config: Optional[ttnn.MemoryConfig] = None, cache_file_name: Optional[Union[str, pathlib.Path]] = None, preprocess: Optional[Callable[[ttnn.Tensor], ttnn.Tensor]] = None, - mesh_mapper: Optional[ttnn.TensorToMesh] = None, + mesh_mapper: Union[ttnn.TensorToMesh, ttnn.CppTensorToMesh] = None, use_device_tilizer: bool = False, ) -> ttnn.Tensor: """ @@ -571,7 +575,7 @@ def torch_to_ttnn( layout: Optional[ttnn.Layout], device: Optional[ttnn.Device], memory_config: Optional[ttnn.MemoryConfig], - mesh_mapper: Optional[ttnn.TensorToMesh], + mesh_mapper: Union[ttnn.TensorToMesh, ttnn.CppTensorToMesh], ): if preprocess: tensor = preprocess(tensor) @@ -604,7 +608,7 @@ def from_torch_and_dump( dtype: Optional[ttnn.DataType], layout: Optional[ttnn.Layout], cache_file_name: str, - mesh_mapper: Optional[ttnn.TensorToMesh], + mesh_mapper: Union[ttnn.TensorToMesh, ttnn.CppTensorToMesh], ): tensor = torch_to_ttnn(tensor, dtype, layout, device, memory_config, mesh_mapper) logger.debug( @@ -615,7 +619,9 @@ def from_torch_and_dump( ttnn._ttnn.tensor.dump_tensor(cache_file_name, tensor, distributed_config) return tensor - if isinstance(mesh_mapper, ttnn.ReplicateTensorToMesh): + if isinstance(mesh_mapper, ttnn.ReplicateTensorToMesh) or isinstance( + mesh_mapper, ttnn.CppReplicateTensorToMesh + ): storage_type = f"_multi_device" if mesh_mapper else "" elif mesh_mapper: storage_type = f"_multi_device_{device.get_num_devices()}" From 06eacacab47d12af63ca70ab30e50b0063fe64e8 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Thu, 27 Feb 2025 20:13:14 +0000 Subject: [PATCH 49/76] fix from_torch device, typing errors --- .../distributed/test_distributed_tensor.py | 471 +++++++++--------- .../ttnn/distributed/distributed_pybind.cpp | 3 +- .../ttnn/distributed/distributed_tensor.hpp | 6 - ttnn/ttnn/operations/core.py | 28 +- 4 files changed, 246 insertions(+), 262 deletions(-) diff --git a/tests/ttnn/distributed/test_distributed_tensor.py b/tests/ttnn/distributed/test_distributed_tensor.py index 0bd58d849dd..dae94f12ddb 100644 --- a/tests/ttnn/distributed/test_distributed_tensor.py +++ b/tests/ttnn/distributed/test_distributed_tensor.py @@ -17,11 +17,11 @@ ], indirect=True, ) -@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.int32, ttnn.uint8, ttnn.uint32]) +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) def test_direct_replicate_to_tensor_mesh(mesh_device, dtype): torch.manual_seed(1234) - mapper = ttnn.ReplicateTensorToMesh(mesh_device) + mapper = ttnn.CppReplicateTensorToMesh(mesh_device) torch_tensor = torch.randn(1, 1, 32, 256) 
replicated_tensors = ttnn.from_torch( @@ -39,282 +39,261 @@ def test_direct_replicate_to_tensor_mesh(mesh_device, dtype): assert out_pass -# # @pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) -# @pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.int32, ttnn.uint8, ttnn.uint32]) -# def test_direct_replicate_to_tensor_mesh(mesh_device, dtype): -# torch.manual_seed(1234) - -# mapper = ttnn.CppReplicateTensorToMesh(mesh_device) - -# torch_tensor = torch.randn(1, 1, 32, 256) -# replicated_tensors = ttnn.from_torch( -# torch_tensor, -# dtype=dtype, -# layout=ttnn.TILE_LAYOUT, -# mesh_mapper = mapper, -# device=mesh_device, -# ) - -# out_tensors = ttnn.get_device_tensors(replicated_tensors) - -# out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensors[0]), torch_tensor, pcc=0.99) -# logger.info(f"PCC value: {out_pcc}") -# assert out_pass - -# @pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) -# def test_direct_shard_to_tensor_mesh(mesh_device, dtype): -# torch.manual_seed(1234) - -# mapper = ttnn.ShardTensorToMesh(mesh_device, dim=3) - -# torch_tensor = torch.randn(1, 1, 32, 256) -# sharded_tensor = ttnn.from_torch( -# torch_tensor, -# dtype=dtype, -# layout=ttnn.TILE_LAYOUT, -# mesh_mapper = mapper, -# device=mesh_device, -# ) - -# out_pass, out_pcc = comp_pcc(ttnn.to_torch(sharded_tensor), torch_tensor, pcc=0.99) -# logger.info(f"PCC value: {out_pcc}") -# assert out_pass - -# @pytest.mark.parametrize( -# "mesh_shape, mesh_device", [pytest.param((8, 4), (8, 4), id="8x4_grid")], indirect=["mesh_device"] -# ) -# @pytest.mark.parametrize( -# "M, K, N", -# [pytest.param(32, 64, 128), pytest.param(32, 128, 64)], -# ) -# @pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) -# def test_direct_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): -# torch.manual_seed(1234) - -# torch_tensor = torch.randn(1, 1, M, K) -# core_grid = ttnn.CoreGrid(y=1, x=8) - -# # If K < N it's FF1-like test case, else FF2-like test case -# shard_dim = (0, 3) if K < N else (3, 0) - -# K = K // mesh_shape[1] if K < N else K // mesh_shape[0] -# N = N // mesh_shape[0] if K < N else N // mesh_shape[1] - -# sharded_mem_config = ttnn.create_sharded_memory_config( -# shape=(M // core_grid.y, K // core_grid.x), -# core_grid=core_grid, -# strategy=ttnn.ShardStrategy.WIDTH, -# orientation=ttnn.ShardOrientation.ROW_MAJOR, -# use_height_and_width_as_shard_shape=True, -# ) - -# mapper = ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) - -# sharded_tensor = ttnn.from_torch( -# torch_tensor, -# dtype=dtype, -# layout=ttnn.TILE_LAYOUT, -# memory_config=sharded_mem_config if M == 32 else ttnn.DRAM_MEMORY_CONFIG, -# mesh_mapper = mapper, -# device=mesh_device, -# ) +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +def test_direct_shard_to_tensor_mesh(mesh_device, dtype): + torch.manual_seed(1234) + + mapper = ttnn.CppShardTensorToMesh(mesh_device, dim=3) + + torch_tensor = torch.randn(1, 1, 32, 256) + sharded_tensor = ttnn.from_torch( + torch_tensor, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + mesh_mapper=mapper, + device=mesh_device, + ) + + out_pass, out_pcc = comp_pcc(ttnn.to_torch(sharded_tensor), torch_tensor, pcc=0.99) + logger.info(f"PCC value: {out_pcc}") + assert out_pass + + +@pytest.mark.parametrize( + "mesh_shape, mesh_device", [pytest.param((8, 4), (8, 4), 
id="8x4_grid")], indirect=["mesh_device"] +) +@pytest.mark.parametrize( + "M, K, N", + [pytest.param(32, 64, 128), pytest.param(32, 128, 64)], +) +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +def test_direct_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): + torch.manual_seed(1234) + + torch_tensor = torch.randn(1, 1, M, K) + core_grid = ttnn.CoreGrid(y=1, x=8) + + # If K < N it's FF1-like test case, else FF2-like test case + shard_dim = (0, 3) if K < N else (3, 0) + + K = K // mesh_shape[1] if K < N else K // mesh_shape[0] + N = N // mesh_shape[0] if K < N else N // mesh_shape[1] + + sharded_mem_config = ttnn.create_sharded_memory_config( + shape=(M // core_grid.y, K // core_grid.x), + core_grid=core_grid, + strategy=ttnn.ShardStrategy.WIDTH, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + use_height_and_width_as_shard_shape=True, + ) + + mapper = ttnn.CppShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) + + sharded_tensor = ttnn.from_torch( + torch_tensor, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + memory_config=sharded_mem_config if M == 32 else ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=mapper, + device=mesh_device, + ) + + out_pass, out_pcc = comp_pcc(ttnn.to_torch(sharded_tensor), torch_tensor, pcc=0.99) + logger.info(f"PCC value: {out_pcc}") + assert out_pass + + +@pytest.mark.parametrize( + "mesh_device", + [ + 32, + ], + indirect=True, +) +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +def test_replicate_to_tensor_mesh(mesh_device, dtype): + torch.manual_seed(1234) + + torch_tensor = torch.randn(1, 1, 32, 256) + to_repl = ttnn.from_torch( + torch_tensor, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + device=mesh_device, + ) + + mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) + replicated_tensors = ttnn.distribute_tensor(to_repl, mapper, mesh_device) + out_tensors = ttnn.get_device_tensors(replicated_tensors) + + out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensors[0]), torch_tensor, pcc=0.99) + logger.info(f"PCC value: {out_pcc}") + assert out_pass + + +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +def test_shard_to_tensor_mesh(mesh_device, dtype): + torch.manual_seed(1234) -# out_pass, out_pcc = comp_pcc(ttnn.to_torch(sharded_tensor), torch_tensor, pcc=0.99) -# logger.info(f"PCC value: {out_pcc}") -# assert out_pass + torch_tensor = torch.randn(1, 1, 32, 256) + to_shard = ttnn.from_torch( + torch_tensor, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + device=mesh_device, + ) + mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3) -# @pytest.mark.parametrize( -# "mesh_device", -# [ -# 32, -# ], -# indirect=True, -# ) -# @pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) -# def test_replicate_to_tensor_mesh(mesh_device, dtype): -# torch.manual_seed(1234) - -# torch_tensor = torch.randn(1, 1, 32, 256) -# to_repl = ttnn.from_torch( -# torch_tensor, -# dtype=dtype, -# layout=ttnn.TILE_LAYOUT, -# device=mesh_device, -# ) - -# mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) -# replicated_tensors = ttnn.distribute_tensor(to_repl, mapper, mesh_device) -# out_tensors = ttnn.get_device_tensors(replicated_tensors) - -# out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensors[0]), torch_tensor, pcc=0.99) -# logger.info(f"PCC value: {out_pcc}") -# assert out_pass - - -# 
@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) -# def test_shard_to_tensor_mesh(mesh_device, dtype): -# torch.manual_seed(1234) - -# torch_tensor = torch.randn(1, 1, 32, 256) -# to_shard = ttnn.from_torch( -# torch_tensor, -# dtype=dtype, -# layout=ttnn.TILE_LAYOUT, -# device=mesh_device, -# ) - -# mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3) - -# shards = ttnn.get_device_tensors(ttnn.distribute_tensor(to_shard, mapper, mesh_device)) - -# out_tensor = ttnn.aggregate_as_tensor(shards) - -# out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensor), torch_tensor, pcc=0.99) -# logger.info(f"PCC value: {out_pcc}") -# assert out_pass + shards = ttnn.get_device_tensors(ttnn.distribute_tensor(to_shard, mapper, mesh_device)) + out_tensor = ttnn.aggregate_as_tensor(shards) -# @pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) -# def test_concat_to_tensor(mesh_device, dtype): -# torch.manual_seed(1234) - -# torch_tensor = torch.randn(1, 1, 32, 256) -# to_shard = ttnn.from_torch( -# torch_tensor, -# dtype=dtype, -# layout=ttnn.TILE_LAYOUT, -# device=mesh_device, -# ) + out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensor), torch_tensor, pcc=0.99) + logger.info(f"PCC value: {out_pcc}") + assert out_pass -# mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3) -# composer = ttnn.concat_mesh_to_tensor_composer(dim=3) +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +def test_concat_to_tensor(mesh_device, dtype): + torch.manual_seed(1234) -# out_tensor = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device), composer) + torch_tensor = torch.randn(1, 1, 32, 256) + to_shard = ttnn.from_torch( + torch_tensor, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + device=mesh_device, + ) -# out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensor), torch_tensor, pcc=0.99) -# logger.info(f"PCC value: {out_pcc}") -# assert out_pass + mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3) + composer = ttnn.concat_mesh_to_tensor_composer(dim=3) -# @pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) -# def test_concat_slice_to_tensor(mesh_device, dtype): -# torch.manual_seed(1234) - -# torch_tensor = torch.randn(1, 1, 32, 256) -# to_shard = ttnn.from_torch( -# torch_tensor, -# dtype=dtype, -# layout=ttnn.TILE_LAYOUT, -# device=mesh_device, -# ) - -# mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3) - -# composer = ttnn.concat_mesh_to_tensor_composer(dim=3) - -# sharded_tensor = ttnn.distribute_tensor(to_shard, mapper, mesh_device) + out_tensor = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device), composer) -# shards = ttnn.get_device_tensors(sharded_tensor) + out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensor), torch_tensor, pcc=0.99) + logger.info(f"PCC value: {out_pcc}") + assert out_pass -# out_tensor = ttnn.aggregate_tensor(shards, composer) -# out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensor), torch_tensor, pcc=0.99) -# logger.info(f"PCC value: {out_pcc}") -# assert out_pass +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +def test_concat_slice_to_tensor(mesh_device, dtype): + torch.manual_seed(1234) + torch_tensor = torch.randn(1, 1, 32, 256) + to_shard = ttnn.from_torch( + torch_tensor, + dtype=dtype, + 
layout=ttnn.TILE_LAYOUT, + device=mesh_device, + ) -# @pytest.mark.parametrize( -# "mesh_shape, mesh_device", [pytest.param((8, 4), (8, 4), id="8x4_grid")], indirect=["mesh_device"] -# ) -# @pytest.mark.parametrize( -# "M, K, N", -# [pytest.param(32, 64, 128), pytest.param(32, 128, 64)], -# ) -# @pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) -# def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): -# torch.manual_seed(1234) + mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3) -# torch_tensor = torch.randn(1, 1, M, K) -# core_grid = ttnn.CoreGrid(y=1, x=8) + composer = ttnn.concat_mesh_to_tensor_composer(dim=3) -# # If K < N it's FF1-like test case, else FF2-like test case -# shard_dim = (0, 3) if K < N else (3, 0) + sharded_tensor = ttnn.distribute_tensor(to_shard, mapper, mesh_device) -# K = K // mesh_shape[1] if K < N else K // mesh_shape[0] -# N = N // mesh_shape[0] if K < N else N // mesh_shape[1] + shards = ttnn.get_device_tensors(sharded_tensor) -# sharded_mem_config = ttnn.create_sharded_memory_config( -# shape=(M // core_grid.y, K // core_grid.x), -# core_grid=core_grid, -# strategy=ttnn.ShardStrategy.WIDTH, -# orientation=ttnn.ShardOrientation.ROW_MAJOR, -# use_height_and_width_as_shard_shape=True, -# ) + out_tensor = ttnn.aggregate_tensor(shards, composer) -# to_shard = ttnn.from_torch( -# torch_tensor, -# dtype=dtype, -# layout=ttnn.TILE_LAYOUT, -# memory_config=sharded_mem_config if M == 32 else ttnn.DRAM_MEMORY_CONFIG, -# device=mesh_device, -# ) + out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensor), torch_tensor, pcc=0.99) + logger.info(f"PCC value: {out_pcc}") + assert out_pass -# mapper = ttnn.shard_tensor_to_2d_mesh_mapper(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) -# shards = ttnn.get_device_tensors(ttnn.distribute_tensor(to_shard, mapper, mesh_device)) +@pytest.mark.parametrize( + "mesh_shape, mesh_device", [pytest.param((8, 4), (8, 4), id="8x4_grid")], indirect=["mesh_device"] +) +@pytest.mark.parametrize( + "M, K, N", + [pytest.param(32, 64, 128), pytest.param(32, 128, 64)], +) +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): + torch.manual_seed(1234) + + torch_tensor = torch.randn(1, 1, M, K) + core_grid = ttnn.CoreGrid(y=1, x=8) + + # If K < N it's FF1-like test case, else FF2-like test case + shard_dim = (0, 3) if K < N else (3, 0) + + K = K // mesh_shape[1] if K < N else K // mesh_shape[0] + N = N // mesh_shape[0] if K < N else N // mesh_shape[1] + + sharded_mem_config = ttnn.create_sharded_memory_config( + shape=(M // core_grid.y, K // core_grid.x), + core_grid=core_grid, + strategy=ttnn.ShardStrategy.WIDTH, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + use_height_and_width_as_shard_shape=True, + ) + + to_shard = ttnn.from_torch( + torch_tensor, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + memory_config=sharded_mem_config if M == 32 else ttnn.DRAM_MEMORY_CONFIG, + device=mesh_device, + ) + + mapper = ttnn.shard_tensor_to_2d_mesh_mapper(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) + + shards = ttnn.get_device_tensors(ttnn.distribute_tensor(to_shard, mapper, mesh_device)) + + ttnn.aggregate_as_tensor(shards) -# ttnn.aggregate_as_tensor(shards) + out_pass, out_pcc = comp_pcc(ttnn.to_torch(shards), torch_tensor, pcc=0.99) + logger.info(f"PCC value: {out_pcc}") + assert out_pass -# out_pass, out_pcc = 
comp_pcc(ttnn.to_torch(shards), torch_tensor, pcc=0.99) -# logger.info(f"PCC value: {out_pcc}") -# assert out_pass +@pytest.mark.parametrize( + "mesh_shape, mesh_device", [pytest.param((8, 4), (8, 4), id="8x4_grid")], indirect=["mesh_device"] +) +@pytest.mark.parametrize( + "M, K, N", + [pytest.param(32, 128, 64), pytest.param(32, 128, 64)], +) +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +def test_concat2d_to_tensor(M, K, N, dtype, mesh_shape, mesh_device): + torch.manual_seed(1234) -# @pytest.mark.parametrize( -# "mesh_shape, mesh_device", [pytest.param((8, 4), (8, 4), id="8x4_grid")], indirect=["mesh_device"] -# ) -# @pytest.mark.parametrize( -# "M, K, N", -# [pytest.param(32, 128, 64), pytest.param(32, 128, 64)], -# ) -# @pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) -# def test_concat2d_to_tensor(M, K, N, dtype, mesh_shape, mesh_device): -# torch.manual_seed(1234) + torch_tensor = torch.randn(1, 1, M, K) + core_grid = ttnn.CoreGrid(y=1, x=8) -# torch_tensor = torch.randn(1, 1, M, K) -# core_grid = ttnn.CoreGrid(y=1, x=8) + # If K < N it's FF1-like test case, else FF2-like test case + shard_dim = (0, 3) if K < N else (3, 0) + concat_dim = (3, 1) if K < N else (1, 3) -# # If K < N it's FF1-like test case, else FF2-like test case -# shard_dim = (0, 3) if K < N else (3, 0) -# concat_dim = (3, 1) if K < N else (1, 3) + K = K // mesh_shape[1] if K < N else K // mesh_shape[0] + N = N // mesh_shape[0] if K < N else N // mesh_shape[1] -# K = K // mesh_shape[1] if K < N else K // mesh_shape[0] -# N = N // mesh_shape[0] if K < N else N // mesh_shape[1] + sharded_mem_config = ttnn.create_sharded_memory_config( + shape=(M // core_grid.y, K // core_grid.x), + core_grid=core_grid, + strategy=ttnn.ShardStrategy.WIDTH, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + use_height_and_width_as_shard_shape=True, + ) -# sharded_mem_config = ttnn.create_sharded_memory_config( -# shape=(M // core_grid.y, K // core_grid.x), -# core_grid=core_grid, -# strategy=ttnn.ShardStrategy.WIDTH, -# orientation=ttnn.ShardOrientation.ROW_MAJOR, -# use_height_and_width_as_shard_shape=True, -# ) + to_shard = ttnn.from_torch( + torch_tensor, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + memory_config=sharded_mem_config if M == 32 else ttnn.DRAM_MEMORY_CONFIG, + device=mesh_device, + ) -# to_shard = ttnn.from_torch( -# torch_tensor, -# dtype=dtype, -# layout=ttnn.TILE_LAYOUT, -# memory_config=sharded_mem_config if M == 32 else ttnn.DRAM_MEMORY_CONFIG, -# device=mesh_device, -# ) + mapper = ttnn.shard_tensor_to_2d_mesh_mapper(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) -# mapper = ttnn.shard_tensor_to_2d_mesh_mapper(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) + composer = ttnn.concat_2d_mesh_to_tensor_composer(mesh_device, dims=concat_dim, mesh_shape=mesh_shape) -# composer = ttnn.concat_2d_mesh_to_tensor_composer(mesh_device, dims=concat_dim, mesh_shape=mesh_shape) + out_tensor = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device), composer) -# out_tensor = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device), composer) - -# out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensor), torch_tensor, pcc=0.99) -# logger.info(f"PCC value: {out_pcc}") -# assert out_pass + out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensor), torch_tensor, pcc=0.99) + logger.info(f"PCC value: {out_pcc}") + assert out_pass diff --git 
a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 35ef5eb1c99..954df686869 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -706,12 +706,13 @@ void py_module(py::module& module) { const TensorToMesh& mapper, std::optional> mesh_device = std::nullopt) -> Tensor { Tensor cpu_tensor; + printf("printingc %d\n", is_multi_device_tensor(cpu_tensor) || is_device_tensor(tensor)); if (tensor.storage_type() == StorageType::MULTI_DEVICE_HOST) { cpu_tensor = tt::tt_metal::tensor_impl::to_host_mesh_tensor_wrapper(tensor, true); } else { cpu_tensor = from_device(tensor); } - + printf("printingc %d\n", is_multi_device_tensor(cpu_tensor) || is_device_tensor(tensor)); return distribute_tensor(cpu_tensor, mapper, mesh_device); }, py::arg("tensor"), diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp index edfd47a80b8..d7705544e2a 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp @@ -13,12 +13,6 @@ #include "ttnn/tensor/xtensor/partition.hpp" #include #include -#include "ttnn/distributed/api.hpp" -#include "ttnn/distributed/distributed_tensor_config.hpp" -#include "ttnn/distributed/types.hpp" -#include "ttnn/tensor/xtensor/partition.hpp" -#include -#include namespace ttnn::distributed { diff --git a/ttnn/ttnn/operations/core.py b/ttnn/ttnn/operations/core.py index 359c0669d32..b0ce21fc37c 100644 --- a/ttnn/ttnn/operations/core.py +++ b/ttnn/ttnn/operations/core.py @@ -156,7 +156,7 @@ def from_torch( layout: Optional[ttnn.Layout] = ttnn.ROW_MAJOR_LAYOUT, device: Optional[ttnn.Device] = None, memory_config: Optional[ttnn.MemoryConfig] = None, - mesh_mapper: Union[ttnn.TensorToMesh, ttnn.CppTensorToMesh] = None, + mesh_mapper: Optional[Union[ttnn.TensorToMesh, ttnn.CppTensorToMesh]] = None, cq_id: Optional[int] = ttnn.DefaultQueueId, ) -> ttnn.Tensor: """ @@ -208,7 +208,7 @@ def from_torch( if dtype == ttnn.bfloat8_b or dtype == ttnn.bfloat4_b: if layout != ttnn.TILE_LAYOUT: raise RuntimeError("ttnn.from_torch: bfloat8_b/bfloat4_b requires TILE_LAYOUT!") - # Tilize tensor, TODO: this is incredibly non-performant + # Tilize tensor, TODO: this is incredibly non-performant when done on host tensor = ttnn.from_torch(tensor, layout=ttnn.TILE_LAYOUT, tile=tile, pad_value=pad_value, mesh_mapper=None) logical_shape = tensor.shape padded_shape = tensor.padded_shape @@ -219,22 +219,32 @@ def from_torch( if mesh_mapper: if isinstance(mesh_mapper, ttnn.CppTensorToMesh): tensor = ttnn.distribute_tensor(tensor, mesh_mapper, device) + if tile is not None: + tensor = ttnn.Tensor(ttnn.to_torch(tensor), dtype, {}, tile) else: - shards = mesh_mapper.map(tensor) - tensor = ttnn.aggregate_as_tensor(shards) - - if tile is not None: - tensor = ttnn.Tensor(tensor, dtype, {}, tile) + shards = mesh_mapper.map(ttnn.to_torch(tensor)) + if tile is not None: + tensor = ttnn.Tensor(shards, dtype, mesh_mapper.config(), tile) + else: + tensor = ttnn.Tensor(shards, dtype, mesh_mapper.config()) + else: + if tile is not None: + tensor = ttnn.Tensor(ttnn.to_torch(tensor), dtype, {}, tile) if layout is not None and not (dtype == ttnn.bfloat8_b or dtype == ttnn.bfloat4_b): if pad_value is not None: tensor = tensor.pad_to_tile(pad_value) - tensor = ttnn.to_layout(tensor, layout) + if ttnn.is_tensor_storage_on_device(tensor): + # TODO: support tilizing non bfloat/float types on device tensors making 
this expensive conversion unnecessary + tensor = ttnn.from_device(tensor, cq_id=cq_id) + tensor = ttnn.to_layout(tensor, layout, device=device) if device is not None: if memory_config is None: memory_config = ttnn.DRAM_MEMORY_CONFIG - tensor = ttnn.to_device(tensor, device, memory_config=memory_config, cq_id=cq_id) + # Handle sharding case which will have already output to a multidevice + if not ttnn.is_tensor_storage_on_device(tensor): + tensor = ttnn.to_device(tensor, device, memory_config=memory_config, cq_id=cq_id) if logical_shape is not None and logical_shape != tensor.shape and mesh_mapper is None: tensor = ttnn.reshape(tensor, logical_shape, padded_shape) From 7c83bcb0ee98b65f465105b8b2594234affd1bfd Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Fri, 28 Feb 2025 06:06:14 +0000 Subject: [PATCH 50/76] remove debug prints --- ttnn/cpp/ttnn/distributed/distributed_pybind.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 954df686869..2ce883fe23a 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -706,13 +706,11 @@ void py_module(py::module& module) { const TensorToMesh& mapper, std::optional> mesh_device = std::nullopt) -> Tensor { Tensor cpu_tensor; - printf("printingc %d\n", is_multi_device_tensor(cpu_tensor) || is_device_tensor(tensor)); if (tensor.storage_type() == StorageType::MULTI_DEVICE_HOST) { cpu_tensor = tt::tt_metal::tensor_impl::to_host_mesh_tensor_wrapper(tensor, true); } else { cpu_tensor = from_device(tensor); } - printf("printingc %d\n", is_multi_device_tensor(cpu_tensor) || is_device_tensor(tensor)); return distribute_tensor(cpu_tensor, mapper, mesh_device); }, py::arg("tensor"), From e2c189fb0c3d5ab3214161f21f6fcae9a89b42eb Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Fri, 28 Feb 2025 08:22:36 +0000 Subject: [PATCH 51/76] reformat tilize, fix golden comparisons in testing, add direct_concat tests, simplify parallel signatures --- .../distributed/test_distributed_tensor.py | 92 +++++++++++++++++-- .../ttnn/distributed/distributed_pybind.cpp | 6 +- ttnn/ttnn/operations/core.py | 23 +++-- 3 files changed, 99 insertions(+), 22 deletions(-) diff --git a/tests/ttnn/distributed/test_distributed_tensor.py b/tests/ttnn/distributed/test_distributed_tensor.py index dae94f12ddb..1bfa519c5e4 100644 --- a/tests/ttnn/distributed/test_distributed_tensor.py +++ b/tests/ttnn/distributed/test_distributed_tensor.py @@ -34,7 +34,7 @@ def test_direct_replicate_to_tensor_mesh(mesh_device, dtype): out_tensors = ttnn.get_device_tensors(replicated_tensors) - out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensors[0]), torch_tensor, pcc=0.99) + out_pass, out_pcc = comp_pcc(torch_tensor, ttnn.to_torch(out_tensors[0]), pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass @@ -54,7 +54,7 @@ def test_direct_shard_to_tensor_mesh(mesh_device, dtype): device=mesh_device, ) - out_pass, out_pcc = comp_pcc(ttnn.to_torch(sharded_tensor), torch_tensor, pcc=0.99) + out_pass, out_pcc = comp_pcc(torch_tensor, ttnn.to_torch(sharded_tensor), pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass @@ -98,7 +98,83 @@ def test_direct_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): device=mesh_device, ) - out_pass, out_pcc = comp_pcc(ttnn.to_torch(sharded_tensor), torch_tensor, pcc=0.99) + out_pass, out_pcc = comp_pcc(ttnn.to_torch(torch_tensor, sharded_tensor), pcc=0.99) + logger.info(f"PCC 
value: {out_pcc}") + assert out_pass + + +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +def test_direct_concat_to_tensor_mesh(mesh_device, dtype): + torch.manual_seed(1234) + + mapper = ttnn.CppShardTensorToMesh(mesh_device, dim=3) + + torch_tensor = torch.randn(1, 1, 32, 256) + sharded_tensor = ttnn.from_torch( + torch_tensor, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + mesh_mapper=mapper, + device=mesh_device, + ) + + composer = ttnn.CppConcatMeshToTensor(mesh_device, dim=3) + + concat_tensor = ttnn.to_torch(sharded_tensor, mesh_composer=composer) + + out_pass, out_pcc = comp_pcc(torch_tensor, concat_tensor, pcc=0.99) + logger.info(f"PCC value: {out_pcc}") + print("attempt") + print(torch_tensor) + print(concat_tensor) + assert out_pass + + +@pytest.mark.parametrize( + "mesh_shape, mesh_device", [pytest.param((8, 4), (8, 4), id="8x4_grid")], indirect=["mesh_device"] +) +@pytest.mark.parametrize( + "M, K, N", + [pytest.param(32, 64, 128), pytest.param(32, 128, 64)], +) +@pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) +def test_direct_concat2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): + torch.manual_seed(1234) + + torch_tensor = torch.randn(1, 1, M, K) + core_grid = ttnn.CoreGrid(y=1, x=8) + + # If K < N it's FF1-like test case, else FF2-like test case + shard_dim = (0, 3) if K < N else (3, 0) + concat_dim = (3, 1) if K < N else (1, 3) + + K = K // mesh_shape[1] if K < N else K // mesh_shape[0] + N = N // mesh_shape[0] if K < N else N // mesh_shape[1] + + sharded_mem_config = ttnn.create_sharded_memory_config( + shape=(M // core_grid.y, K // core_grid.x), + core_grid=core_grid, + strategy=ttnn.ShardStrategy.WIDTH, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + use_height_and_width_as_shard_shape=True, + ) + + mapper = ttnn.CppShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) + + sharded_tensor = ttnn.from_torch( + torch_tensor, + dtype=dtype, + layout=ttnn.TILE_LAYOUT, + memory_config=sharded_mem_config if M == 32 else ttnn.DRAM_MEMORY_CONFIG, + mesh_mapper=mapper, + device=mesh_device, + ) + + composer = ttnn.CppConcat2dMeshToTensor(mesh_device, mesh_shape, dim=concat_dim) + + concat_tensor = ttnn.to_torch(sharded_tensor, mesh_composer=composer) + + out_pass, out_pcc = comp_pcc(torch_tensor, concat_tensor, pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass @@ -149,7 +225,7 @@ def test_shard_to_tensor_mesh(mesh_device, dtype): out_tensor = ttnn.aggregate_as_tensor(shards) - out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensor), torch_tensor, pcc=0.99) + out_pass, out_pcc = comp_pcc(torch_tensor, ttnn.to_torch(out_tensor), pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass @@ -172,7 +248,7 @@ def test_concat_to_tensor(mesh_device, dtype): out_tensor = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device), composer) - out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensor), torch_tensor, pcc=0.99) + out_pass, out_pcc = comp_pcc(torch_tensor, ttnn.to_torch(out_tensor), pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass @@ -199,7 +275,7 @@ def test_concat_slice_to_tensor(mesh_device, dtype): out_tensor = ttnn.aggregate_tensor(shards, composer) - out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensor), torch_tensor, pcc=0.99) + out_pass, out_pcc = comp_pcc(torch_tensor, ttnn.to_torch(out_tensor), pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass @@ -246,7 +322,7 @@ 
def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): ttnn.aggregate_as_tensor(shards) - out_pass, out_pcc = comp_pcc(ttnn.to_torch(shards), torch_tensor, pcc=0.99) + out_pass, out_pcc = comp_pcc(torch_tensor, ttnn.to_torch(shards), pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass @@ -294,6 +370,6 @@ def test_concat2d_to_tensor(M, K, N, dtype, mesh_shape, mesh_device): out_tensor = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device), composer) - out_pass, out_pcc = comp_pcc(ttnn.to_torch(out_tensor), torch_tensor, pcc=0.99) + out_pass, out_pcc = comp_pcc(torch_tensor, tnn.to_torch(out_tensor), pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 2ce883fe23a..ba067a8f610 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -476,10 +476,7 @@ void py_module(py::module& module) { static_cast>>(module.attr("CppConcatMeshToTensor")); py_concat_mesh_to_tensor .def( - py::init([](MeshDevice& mesh_device, int dim) -> std::unique_ptr { - return concat_mesh_to_tensor_composer(dim); - }), - py::arg("mesh_device"), + py::init([](int dim) -> std::unique_ptr { return concat_mesh_to_tensor_composer(dim); }), py::arg("dim")) .def( "compose", @@ -705,6 +702,7 @@ void py_module(py::module& module) { [](const Tensor& tensor, const TensorToMesh& mapper, std::optional> mesh_device = std::nullopt) -> Tensor { + Tensor cpu_tensor; if (tensor.storage_type() == StorageType::MULTI_DEVICE_HOST) { cpu_tensor = tt::tt_metal::tensor_impl::to_host_mesh_tensor_wrapper(tensor, true); diff --git a/ttnn/ttnn/operations/core.py b/ttnn/ttnn/operations/core.py index b0ce21fc37c..63f5fee11f9 100644 --- a/ttnn/ttnn/operations/core.py +++ b/ttnn/ttnn/operations/core.py @@ -216,20 +216,23 @@ def from_torch( else: tensor = ttnn.Tensor(tensor, dtype) + strategy = {} + tilize_input = [] + if mesh_mapper: if isinstance(mesh_mapper, ttnn.CppTensorToMesh): tensor = ttnn.distribute_tensor(tensor, mesh_mapper, device) - if tile is not None: - tensor = ttnn.Tensor(ttnn.to_torch(tensor), dtype, {}, tile) + tilize_input = ttnn.to_torch(tensor) else: + strategy = mesh_mapper.config() shards = mesh_mapper.map(ttnn.to_torch(tensor)) - if tile is not None: - tensor = ttnn.Tensor(shards, dtype, mesh_mapper.config(), tile) - else: - tensor = ttnn.Tensor(shards, dtype, mesh_mapper.config()) - else: - if tile is not None: - tensor = ttnn.Tensor(ttnn.to_torch(tensor), dtype, {}, tile) + tilize_input = shards + if tile is None: + tensor = ttnn.Tensor(tilize_input, dtype, strategy) + + # TODO: find cleaner way of tilizing + if tile is not None: + tensor = ttnn.Tensor(tilize_input, dtype, strategy, tile) if layout is not None and not (dtype == ttnn.bfloat8_b or dtype == ttnn.bfloat4_b): if pad_value is not None: @@ -315,7 +318,7 @@ def to_torch( if isinstance(mesh_composer, ttnn.MeshToTensor): return mesh_composer.compose(tensor) else: - return mesh_composer.compose(ttnn.get_device_tensors(tensor)) + return mesh_composer.compose(ttnn.get_device_tensors(tensor)).to_torch() if tensor.storage_type() == ttnn.DEVICE_STORAGE_TYPE: raise RuntimeError("ttnn.Tensor cannot be on device when converting to torch.Tensor!") From d3614d705ebeeb3f055240c1dce4832b431f32c2 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Tue, 4 Mar 2025 22:00:36 +0000 Subject: [PATCH 52/76] fix uint errors --- 
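Notes: the uint errors trace back to generating every test input with torch.randn,
which only produces floating-point (and negative) values, so the ttnn.uint16
parametrizations never received integer data. This patch guards tensor creation on
dtype in each test, switching uint16 inputs to torch.randint. A minimal sketch of the
repeated pattern (the make_test_tensor helper name is illustrative only and not part
of this patch; the follow-up patches in this series lower the upper bound further):

    import torch
    import ttnn

    def make_test_tensor(shape, dtype):
        # torch.randn cannot produce integer data, so integer dtypes use randint.
        if dtype == ttnn.uint16:
            return torch.randint(0, 65535, shape)
        return torch.randn(shape)

    # e.g. torch_tensor = make_test_tensor((1, 1, 32, 256), dtype)
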
.../distributed/test_distributed_tensor.py | 72 ++++++++++++++----- ttnn/cpp/ttnn/distributed/api.cpp | 10 +-- .../ttnn/distributed/distributed_tensor.cpp | 15 ++++ ttnn/ttnn/operations/core.py | 1 - 4 files changed, 72 insertions(+), 26 deletions(-) diff --git a/tests/ttnn/distributed/test_distributed_tensor.py b/tests/ttnn/distributed/test_distributed_tensor.py index 1bfa519c5e4..fc51d062eab 100644 --- a/tests/ttnn/distributed/test_distributed_tensor.py +++ b/tests/ttnn/distributed/test_distributed_tensor.py @@ -23,7 +23,10 @@ def test_direct_replicate_to_tensor_mesh(mesh_device, dtype): mapper = ttnn.CppReplicateTensorToMesh(mesh_device) - torch_tensor = torch.randn(1, 1, 32, 256) + if dtype == ttnn.uint16: + torch_tensor = torch.randint(0, 65535, (1, 1, 32, 256)) + else: + torch_tensor = torch.randn(1, 1, 32, 256) replicated_tensors = ttnn.from_torch( torch_tensor, dtype=dtype, @@ -34,6 +37,12 @@ def test_direct_replicate_to_tensor_mesh(mesh_device, dtype): out_tensors = ttnn.get_device_tensors(replicated_tensors) + # out_pass1, out_pcc1 = comp_pcc(torch_tensor, ttnn.to_torch(ttnn.from_torch(torch_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT,mesh_device=mesh_device)), pcc=0.99) + # print("test") + # print(out_pass1) + # print(out_pcc1) + # assert out_pass1 + out_pass, out_pcc = comp_pcc(torch_tensor, ttnn.to_torch(out_tensors[0]), pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass @@ -45,7 +54,10 @@ def test_direct_shard_to_tensor_mesh(mesh_device, dtype): mapper = ttnn.CppShardTensorToMesh(mesh_device, dim=3) - torch_tensor = torch.randn(1, 1, 32, 256) + if dtype == ttnn.uint16: + torch_tensor = torch.randint(0, 65535, (1, 1, 32, 256)) + else: + torch_tensor = torch.randn(1, 1, 32, 256) sharded_tensor = ttnn.from_torch( torch_tensor, dtype=dtype, @@ -70,7 +82,11 @@ def test_direct_shard_to_tensor_mesh(mesh_device, dtype): def test_direct_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): torch.manual_seed(1234) - torch_tensor = torch.randn(1, 1, M, K) + if dtype == ttnn.uint16: + torch_tensor = torch.randint(0, 65535, (1, 1, M, K)) + else: + torch_tensor = torch.randn(1, 1, M, K) + core_grid = ttnn.CoreGrid(y=1, x=8) # If K < N it's FF1-like test case, else FF2-like test case @@ -109,7 +125,10 @@ def test_direct_concat_to_tensor_mesh(mesh_device, dtype): mapper = ttnn.CppShardTensorToMesh(mesh_device, dim=3) - torch_tensor = torch.randn(1, 1, 32, 256) + if dtype == ttnn.uint16: + torch_tensor = torch.randint(0, 65535, (1, 1, 32, 256)) + else: + torch_tensor = torch.randn(1, 1, 32, 256) sharded_tensor = ttnn.from_torch( torch_tensor, dtype=dtype, @@ -118,15 +137,12 @@ def test_direct_concat_to_tensor_mesh(mesh_device, dtype): device=mesh_device, ) - composer = ttnn.CppConcatMeshToTensor(mesh_device, dim=3) + composer = ttnn.CppConcatMeshToTensor(dim=3) concat_tensor = ttnn.to_torch(sharded_tensor, mesh_composer=composer) out_pass, out_pcc = comp_pcc(torch_tensor, concat_tensor, pcc=0.99) logger.info(f"PCC value: {out_pcc}") - print("attempt") - print(torch_tensor) - print(concat_tensor) assert out_pass @@ -141,7 +157,11 @@ def test_direct_concat_to_tensor_mesh(mesh_device, dtype): def test_direct_concat2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): torch.manual_seed(1234) - torch_tensor = torch.randn(1, 1, M, K) + if dtype == ttnn.uint16: + torch_tensor = torch.randint(0, 65535, (1, 1, M, K)) + else: + torch_tensor = torch.randn(1, 1, M, K) + core_grid = ttnn.CoreGrid(y=1, x=8) # If K < N it's FF1-like test case, else FF2-like test case @@ -170,7 
+190,7 @@ def test_direct_concat2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device) device=mesh_device, ) - composer = ttnn.CppConcat2dMeshToTensor(mesh_device, mesh_shape, dim=concat_dim) + composer = ttnn.CppConcat2dMeshToTensor(mesh_device, dims=concat_dim) concat_tensor = ttnn.to_torch(sharded_tensor, mesh_composer=composer) @@ -190,7 +210,10 @@ def test_direct_concat2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device) def test_replicate_to_tensor_mesh(mesh_device, dtype): torch.manual_seed(1234) - torch_tensor = torch.randn(1, 1, 32, 256) + if dtype == ttnn.uint16: + torch_tensor = torch.randint(0, 65535, (1, 1, 32, 256)) + else: + torch_tensor = torch.randn(1, 1, 32, 256) to_repl = ttnn.from_torch( torch_tensor, dtype=dtype, @@ -211,7 +234,10 @@ def test_replicate_to_tensor_mesh(mesh_device, dtype): def test_shard_to_tensor_mesh(mesh_device, dtype): torch.manual_seed(1234) - torch_tensor = torch.randn(1, 1, 32, 256) + if dtype == ttnn.uint16: + torch_tensor = torch.randint(0, 65535, (1, 1, 32, 256)) + else: + torch_tensor = torch.randn(1, 1, 32, 256) to_shard = ttnn.from_torch( torch_tensor, dtype=dtype, @@ -234,7 +260,10 @@ def test_shard_to_tensor_mesh(mesh_device, dtype): def test_concat_to_tensor(mesh_device, dtype): torch.manual_seed(1234) - torch_tensor = torch.randn(1, 1, 32, 256) + if dtype == ttnn.uint16: + torch_tensor = torch.randint(0, 65535, (1, 1, 32, 256)) + else: + torch_tensor = torch.randn(1, 1, 32, 256) to_shard = ttnn.from_torch( torch_tensor, dtype=dtype, @@ -257,7 +286,10 @@ def test_concat_to_tensor(mesh_device, dtype): def test_concat_slice_to_tensor(mesh_device, dtype): torch.manual_seed(1234) - torch_tensor = torch.randn(1, 1, 32, 256) + if dtype == ttnn.uint16: + torch_tensor = torch.randint(0, 65535, (1, 1, 32, 256)) + else: + torch_tensor = torch.randn(1, 1, 32, 256) to_shard = ttnn.from_torch( torch_tensor, dtype=dtype, @@ -291,7 +323,10 @@ def test_concat_slice_to_tensor(mesh_device, dtype): def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): torch.manual_seed(1234) - torch_tensor = torch.randn(1, 1, M, K) + if dtype == ttnn.uint16: + torch_tensor = torch.randint(0, 65535, (1, 1, M, K)) + else: + torch_tensor = torch.randn(1, 1, M, K) core_grid = ttnn.CoreGrid(y=1, x=8) # If K < N it's FF1-like test case, else FF2-like test case @@ -338,7 +373,10 @@ def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): def test_concat2d_to_tensor(M, K, N, dtype, mesh_shape, mesh_device): torch.manual_seed(1234) - torch_tensor = torch.randn(1, 1, M, K) + if dtype == ttnn.uint16: + torch_tensor = torch.randint(0, 65535, (1, 1, M, K)) + else: + torch_tensor = torch.randn(1, 1, M, K) core_grid = ttnn.CoreGrid(y=1, x=8) # If K < N it's FF1-like test case, else FF2-like test case @@ -370,6 +408,6 @@ def test_concat2d_to_tensor(M, K, N, dtype, mesh_shape, mesh_device): out_tensor = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device), composer) - out_pass, out_pcc = comp_pcc(torch_tensor, tnn.to_torch(out_tensor), pcc=0.99) + out_pass, out_pcc = comp_pcc(torch_tensor, ttnn.to_torch(out_tensor), pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass diff --git a/ttnn/cpp/ttnn/distributed/api.cpp b/ttnn/cpp/ttnn/distributed/api.cpp index 421e9ca3540..9de607d505c 100644 --- a/ttnn/cpp/ttnn/distributed/api.cpp +++ b/ttnn/cpp/ttnn/distributed/api.cpp @@ -103,15 +103,9 @@ Tensor aggregate_as_tensor( specs.push_back(shard.get_tensor_spec()); auto visitor = tt::stl::overloaded{[&shard, 
&host_owned_buffers](const auto& buffer) -> OwnedBuffer { - using BufferType = std::decay_t; - using ValueType = typename BufferType::value_type; + using BorrowedBufferType = std::vector::value_type>; - std::vector physical_data(buffer.begin(), buffer.end()); - - std::vector logical_data = - tensor_impl::decode_tensor_data(std::move(physical_data), shard.get_tensor_spec()); - - return owned_buffer::create(std::move(logical_data)); + return owned_buffer::create(BorrowedBufferType(buffer.begin(), buffer.end())); }}; host_owned_buffers.push_back(std::visit(visitor, buffer)); diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp index a6bb24f4211..72cc919b99d 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp @@ -142,6 +142,21 @@ class Concat2dMeshToTensor : public MeshToTensor { } // namespace +std::vector TensorToMesh::map(const Tensor& tensor) const { + // This function should never be called directly, it's just to satisfy the linker + TT_THROW("Pure virtual function 'map' called - please use or define concrete implementations instead."); +} + +tt::tt_metal::DistributedTensorConfig TensorToMesh::config() const { + // This function should never be called directly, it's just to satisfy the linker + TT_THROW("Pure virtual function 'config' called - please use or define concrete implementations instead."); +} + +Tensor MeshToTensor::compose(const std::vector& tensors) const { + // This function should never be called directly, it's just to satisfy the linker + TT_THROW("Pure virtual function 'compose' called - please use or define concrete implementations instead."); +} + std::unique_ptr replicate_tensor_to_mesh_mapper(MeshDevice& mesh_device) { return std::make_unique(mesh_device.num_devices()); } diff --git a/ttnn/ttnn/operations/core.py b/ttnn/ttnn/operations/core.py index 63f5fee11f9..3141b2da59b 100644 --- a/ttnn/ttnn/operations/core.py +++ b/ttnn/ttnn/operations/core.py @@ -342,7 +342,6 @@ def to_torch( if tensor.shape[0] != 1: raise RuntimeError("ttnn: Unable to squeeze to desired rank!") tensor = tensor.squeeze(0) - torch_tensor = TorchTensor(tensor) if dtype is not None: From cac11212e4ab292be4ac363b7b8d54151acaf5e3 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Tue, 4 Mar 2025 22:47:31 +0000 Subject: [PATCH 53/76] fix out of bounds error --- .../distributed/test_distributed_tensor.py | 28 ++++++++----------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/tests/ttnn/distributed/test_distributed_tensor.py b/tests/ttnn/distributed/test_distributed_tensor.py index fc51d062eab..0e59d03ad66 100644 --- a/tests/ttnn/distributed/test_distributed_tensor.py +++ b/tests/ttnn/distributed/test_distributed_tensor.py @@ -24,7 +24,7 @@ def test_direct_replicate_to_tensor_mesh(mesh_device, dtype): mapper = ttnn.CppReplicateTensorToMesh(mesh_device) if dtype == ttnn.uint16: - torch_tensor = torch.randint(0, 65535, (1, 1, 32, 256)) + torch_tensor = torch.randint(0, 39990, (1, 1, 32, 256)) else: torch_tensor = torch.randn(1, 1, 32, 256) replicated_tensors = ttnn.from_torch( @@ -37,12 +37,6 @@ def test_direct_replicate_to_tensor_mesh(mesh_device, dtype): out_tensors = ttnn.get_device_tensors(replicated_tensors) - # out_pass1, out_pcc1 = comp_pcc(torch_tensor, ttnn.to_torch(ttnn.from_torch(torch_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT,mesh_device=mesh_device)), pcc=0.99) - # print("test") - # print(out_pass1) - # print(out_pcc1) - # assert out_pass1 - 
out_pass, out_pcc = comp_pcc(torch_tensor, ttnn.to_torch(out_tensors[0]), pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass @@ -55,7 +49,7 @@ def test_direct_shard_to_tensor_mesh(mesh_device, dtype): mapper = ttnn.CppShardTensorToMesh(mesh_device, dim=3) if dtype == ttnn.uint16: - torch_tensor = torch.randint(0, 65535, (1, 1, 32, 256)) + torch_tensor = torch.randint(0, 39990, (1, 1, 32, 256)) else: torch_tensor = torch.randn(1, 1, 32, 256) sharded_tensor = ttnn.from_torch( @@ -83,7 +77,7 @@ def test_direct_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): torch.manual_seed(1234) if dtype == ttnn.uint16: - torch_tensor = torch.randint(0, 65535, (1, 1, M, K)) + torch_tensor = torch.randint(0, 39990, (1, 1, M, K)) else: torch_tensor = torch.randn(1, 1, M, K) @@ -126,7 +120,7 @@ def test_direct_concat_to_tensor_mesh(mesh_device, dtype): mapper = ttnn.CppShardTensorToMesh(mesh_device, dim=3) if dtype == ttnn.uint16: - torch_tensor = torch.randint(0, 65535, (1, 1, 32, 256)) + torch_tensor = torch.randint(0, 39990, (1, 1, 32, 256)) else: torch_tensor = torch.randn(1, 1, 32, 256) sharded_tensor = ttnn.from_torch( @@ -158,7 +152,7 @@ def test_direct_concat2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device) torch.manual_seed(1234) if dtype == ttnn.uint16: - torch_tensor = torch.randint(0, 65535, (1, 1, M, K)) + torch_tensor = torch.randint(0, 39990, (1, 1, M, K)) else: torch_tensor = torch.randn(1, 1, M, K) @@ -211,7 +205,7 @@ def test_replicate_to_tensor_mesh(mesh_device, dtype): torch.manual_seed(1234) if dtype == ttnn.uint16: - torch_tensor = torch.randint(0, 65535, (1, 1, 32, 256)) + torch_tensor = torch.randint(0, 39990, (1, 1, 32, 256)) else: torch_tensor = torch.randn(1, 1, 32, 256) to_repl = ttnn.from_torch( @@ -235,7 +229,7 @@ def test_shard_to_tensor_mesh(mesh_device, dtype): torch.manual_seed(1234) if dtype == ttnn.uint16: - torch_tensor = torch.randint(0, 65535, (1, 1, 32, 256)) + torch_tensor = torch.randint(0, 39990, (1, 1, 32, 256)) else: torch_tensor = torch.randn(1, 1, 32, 256) to_shard = ttnn.from_torch( @@ -261,7 +255,7 @@ def test_concat_to_tensor(mesh_device, dtype): torch.manual_seed(1234) if dtype == ttnn.uint16: - torch_tensor = torch.randint(0, 65535, (1, 1, 32, 256)) + torch_tensor = torch.randint(0, 39990, (1, 1, 32, 256)) else: torch_tensor = torch.randn(1, 1, 32, 256) to_shard = ttnn.from_torch( @@ -287,7 +281,7 @@ def test_concat_slice_to_tensor(mesh_device, dtype): torch.manual_seed(1234) if dtype == ttnn.uint16: - torch_tensor = torch.randint(0, 65535, (1, 1, 32, 256)) + torch_tensor = torch.randint(0, 39990, (1, 1, 32, 256)) else: torch_tensor = torch.randn(1, 1, 32, 256) to_shard = ttnn.from_torch( @@ -324,7 +318,7 @@ def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): torch.manual_seed(1234) if dtype == ttnn.uint16: - torch_tensor = torch.randint(0, 65535, (1, 1, M, K)) + torch_tensor = torch.randint(0, 39990, (1, 1, M, K)) else: torch_tensor = torch.randn(1, 1, M, K) core_grid = ttnn.CoreGrid(y=1, x=8) @@ -374,7 +368,7 @@ def test_concat2d_to_tensor(M, K, N, dtype, mesh_shape, mesh_device): torch.manual_seed(1234) if dtype == ttnn.uint16: - torch_tensor = torch.randint(0, 65535, (1, 1, M, K)) + torch_tensor = torch.randint(0, 39990, (1, 1, M, K)) else: torch_tensor = torch.randn(1, 1, M, K) core_grid = ttnn.CoreGrid(y=1, x=8) From befdcc5841c5d85e7805bb534d98693fc4541160 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Tue, 4 Mar 2025 22:50:21 +0000 Subject: [PATCH 54/76] actual fix with correct copy paste 
--- .../distributed/test_distributed_tensor.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/ttnn/distributed/test_distributed_tensor.py b/tests/ttnn/distributed/test_distributed_tensor.py index 0e59d03ad66..fc4a7a9efba 100644 --- a/tests/ttnn/distributed/test_distributed_tensor.py +++ b/tests/ttnn/distributed/test_distributed_tensor.py @@ -24,7 +24,7 @@ def test_direct_replicate_to_tensor_mesh(mesh_device, dtype): mapper = ttnn.CppReplicateTensorToMesh(mesh_device) if dtype == ttnn.uint16: - torch_tensor = torch.randint(0, 39990, (1, 1, 32, 256)) + torch_tensor = torch.randint(0, 32767, (1, 1, 32, 256)) else: torch_tensor = torch.randn(1, 1, 32, 256) replicated_tensors = ttnn.from_torch( @@ -49,7 +49,7 @@ def test_direct_shard_to_tensor_mesh(mesh_device, dtype): mapper = ttnn.CppShardTensorToMesh(mesh_device, dim=3) if dtype == ttnn.uint16: - torch_tensor = torch.randint(0, 39990, (1, 1, 32, 256)) + torch_tensor = torch.randint(0, 32767, (1, 1, 32, 256)) else: torch_tensor = torch.randn(1, 1, 32, 256) sharded_tensor = ttnn.from_torch( @@ -77,7 +77,7 @@ def test_direct_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): torch.manual_seed(1234) if dtype == ttnn.uint16: - torch_tensor = torch.randint(0, 39990, (1, 1, M, K)) + torch_tensor = torch.randint(0, 32767, (1, 1, M, K)) else: torch_tensor = torch.randn(1, 1, M, K) @@ -120,7 +120,7 @@ def test_direct_concat_to_tensor_mesh(mesh_device, dtype): mapper = ttnn.CppShardTensorToMesh(mesh_device, dim=3) if dtype == ttnn.uint16: - torch_tensor = torch.randint(0, 39990, (1, 1, 32, 256)) + torch_tensor = torch.randint(0, 32767, (1, 1, 32, 256)) else: torch_tensor = torch.randn(1, 1, 32, 256) sharded_tensor = ttnn.from_torch( @@ -152,7 +152,7 @@ def test_direct_concat2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device) torch.manual_seed(1234) if dtype == ttnn.uint16: - torch_tensor = torch.randint(0, 39990, (1, 1, M, K)) + torch_tensor = torch.randint(0, 32767, (1, 1, M, K)) else: torch_tensor = torch.randn(1, 1, M, K) @@ -205,7 +205,7 @@ def test_replicate_to_tensor_mesh(mesh_device, dtype): torch.manual_seed(1234) if dtype == ttnn.uint16: - torch_tensor = torch.randint(0, 39990, (1, 1, 32, 256)) + torch_tensor = torch.randint(0, 32767, (1, 1, 32, 256)) else: torch_tensor = torch.randn(1, 1, 32, 256) to_repl = ttnn.from_torch( @@ -229,7 +229,7 @@ def test_shard_to_tensor_mesh(mesh_device, dtype): torch.manual_seed(1234) if dtype == ttnn.uint16: - torch_tensor = torch.randint(0, 39990, (1, 1, 32, 256)) + torch_tensor = torch.randint(0, 32767, (1, 1, 32, 256)) else: torch_tensor = torch.randn(1, 1, 32, 256) to_shard = ttnn.from_torch( @@ -255,7 +255,7 @@ def test_concat_to_tensor(mesh_device, dtype): torch.manual_seed(1234) if dtype == ttnn.uint16: - torch_tensor = torch.randint(0, 39990, (1, 1, 32, 256)) + torch_tensor = torch.randint(0, 32767, (1, 1, 32, 256)) else: torch_tensor = torch.randn(1, 1, 32, 256) to_shard = ttnn.from_torch( @@ -281,7 +281,7 @@ def test_concat_slice_to_tensor(mesh_device, dtype): torch.manual_seed(1234) if dtype == ttnn.uint16: - torch_tensor = torch.randint(0, 39990, (1, 1, 32, 256)) + torch_tensor = torch.randint(0, 32767, (1, 1, 32, 256)) else: torch_tensor = torch.randn(1, 1, 32, 256) to_shard = ttnn.from_torch( @@ -318,7 +318,7 @@ def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): torch.manual_seed(1234) if dtype == ttnn.uint16: - torch_tensor = torch.randint(0, 39990, (1, 1, M, K)) + torch_tensor = torch.randint(0, 32767, 
(1, 1, M, K)) else: torch_tensor = torch.randn(1, 1, M, K) core_grid = ttnn.CoreGrid(y=1, x=8) @@ -368,7 +368,7 @@ def test_concat2d_to_tensor(M, K, N, dtype, mesh_shape, mesh_device): torch.manual_seed(1234) if dtype == ttnn.uint16: - torch_tensor = torch.randint(0, 39990, (1, 1, M, K)) + torch_tensor = torch.randint(0, 32767, (1, 1, M, K)) else: torch_tensor = torch.randn(1, 1, M, K) core_grid = ttnn.CoreGrid(y=1, x=8) From a5c6847876e89d133c4424e781f70cc3d501013f Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Wed, 5 Mar 2025 15:45:06 +0000 Subject: [PATCH 55/76] make the switch, satisfy linker without dummy virtual function definitions --- .../distributed/test_distributed_tensor.py | 10 +- .../ttnn/distributed/distributed_pybind.cpp | 39 +--- .../ttnn/distributed/distributed_tensor.cpp | 23 --- .../ttnn/distributed/distributed_tensor.hpp | 12 +- ttnn/ttnn/__init__.py | 10 +- ttnn/ttnn/distributed/__init__.py | 5 - ttnn/ttnn/distributed/distributed.py | 169 ------------------ ttnn/ttnn/operations/core.py | 23 +-- 8 files changed, 29 insertions(+), 262 deletions(-) diff --git a/tests/ttnn/distributed/test_distributed_tensor.py b/tests/ttnn/distributed/test_distributed_tensor.py index fc4a7a9efba..216ac1364e3 100644 --- a/tests/ttnn/distributed/test_distributed_tensor.py +++ b/tests/ttnn/distributed/test_distributed_tensor.py @@ -21,7 +21,7 @@ def test_direct_replicate_to_tensor_mesh(mesh_device, dtype): torch.manual_seed(1234) - mapper = ttnn.CppReplicateTensorToMesh(mesh_device) + mapper = ttnn.ReplicateTensorToMesh(mesh_device) if dtype == ttnn.uint16: torch_tensor = torch.randint(0, 32767, (1, 1, 32, 256)) @@ -46,7 +46,7 @@ def test_direct_replicate_to_tensor_mesh(mesh_device, dtype): def test_direct_shard_to_tensor_mesh(mesh_device, dtype): torch.manual_seed(1234) - mapper = ttnn.CppShardTensorToMesh(mesh_device, dim=3) + mapper = ttnn.ShardTensorToMesh(mesh_device, dim=3) if dtype == ttnn.uint16: torch_tensor = torch.randint(0, 32767, (1, 1, 32, 256)) @@ -97,7 +97,7 @@ def test_direct_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): use_height_and_width_as_shard_shape=True, ) - mapper = ttnn.CppShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) + mapper = ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) sharded_tensor = ttnn.from_torch( torch_tensor, @@ -117,7 +117,7 @@ def test_direct_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): def test_direct_concat_to_tensor_mesh(mesh_device, dtype): torch.manual_seed(1234) - mapper = ttnn.CppShardTensorToMesh(mesh_device, dim=3) + mapper = ttnn.ShardTensorToMesh(mesh_device, dim=3) if dtype == ttnn.uint16: torch_tensor = torch.randint(0, 32767, (1, 1, 32, 256)) @@ -173,7 +173,7 @@ def test_direct_concat2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device) use_height_and_width_as_shard_shape=True, ) - mapper = ttnn.CppShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) + mapper = ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) sharded_tensor = ttnn.from_torch( torch_tensor, diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index ba067a8f610..7b9f2c542bb 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -58,13 +58,12 @@ class Concat2dMeshToTensor : public MeshToTensor {}; void py_module_types(py::module& module) { py::class_>(module, "CppMeshToTensor"); - py::class_>(module, 
"CppTensorToMesh"); + py::class_>(module, "TensorToMesh"); py::class_>( - module, "CppReplicateTensorToMesh"); - py::class_>(module, "CppShardTensorToMesh"); - py::class_>( - module, "CppShardTensorTo2dMesh"); + module, "ReplicateTensorToMesh"); + py::class_>(module, "ShardTensorToMesh"); + py::class_>(module, "ShardTensorTo2dMesh"); py::class_>(module, "CppConcatMeshToTensor"); py::class_>( module, "CppConcat2dMeshToTensor"); @@ -415,13 +414,13 @@ void py_module(py::module& module) { )doc"); auto py_tensor_to_mesh = static_cast>>( - module.attr("CppTensorToMesh")); + module.attr("TensorToMesh")); py_tensor_to_mesh .def(py::init([]() -> std::unique_ptr { return std::make_unique(); })) .def("map", &TensorToMesh::map) .def("config", &TensorToMesh::config); auto py_replicate_tensor_to_mesh = - static_cast>>(module.attr("CppReplicateTensorToMesh")); + static_cast>>(module.attr("ReplicateTensorToMesh")); py_replicate_tensor_to_mesh .def( py::init([](MeshDevice& mesh_device) -> std::unique_ptr { @@ -432,7 +431,7 @@ void py_module(py::module& module) { "map", [](const TensorToMesh& self, const Tensor& tensor) { return self.map(tensor); }, py::arg("tensor")) .def("config", &TensorToMesh::config); auto py_shard_tensor_to_mesh = - static_cast>>(module.attr("CppShardTensorToMesh")); + static_cast>>(module.attr("ShardTensorToMesh")); py_shard_tensor_to_mesh .def( py::init([](MeshDevice& mesh_device, int dim) -> std::unique_ptr { @@ -444,7 +443,7 @@ void py_module(py::module& module) { "map", [](const TensorToMesh& self, const Tensor& tensor) { return self.map(tensor); }, py::arg("tensor")) .def("config", &TensorToMesh::config); auto py_shard_tensor_to_2d_mesh = - static_cast>>(module.attr("CppShardTensorTo2dMesh")); + static_cast>>(module.attr("ShardTensorTo2dMesh")); py_shard_tensor_to_2d_mesh .def( py::init( @@ -579,28 +578,6 @@ void py_module(py::module& module) { "item": "field", } )doc"); - module.def( - "get_shard2d_config", - &get_shard2d_config, - py::arg("metadata"), - R"doc( - Returns a Shard2dConfig object given a valid metadata object of the type - { - "row_dim": "field", - "col_dim": "field", - } - )doc"); - module.def( - "get_concat2d_config", - &get_concat2d_config, - py::arg("metadata"), - R"doc( - Returns a Concat2dConfig object given a valid metadata object of the type - { - "row_dim": "field", - "col_dim": "field", - } - )doc"); module.def( "get_device_tensor", py::overload_cast(&ttnn::distributed::get_device_tensor), diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp index 72cc919b99d..19646bf8bb1 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp @@ -142,21 +142,6 @@ class Concat2dMeshToTensor : public MeshToTensor { } // namespace -std::vector TensorToMesh::map(const Tensor& tensor) const { - // This function should never be called directly, it's just to satisfy the linker - TT_THROW("Pure virtual function 'map' called - please use or define concrete implementations instead."); -} - -tt::tt_metal::DistributedTensorConfig TensorToMesh::config() const { - // This function should never be called directly, it's just to satisfy the linker - TT_THROW("Pure virtual function 'config' called - please use or define concrete implementations instead."); -} - -Tensor MeshToTensor::compose(const std::vector& tensors) const { - // This function should never be called directly, it's just to satisfy the linker - TT_THROW("Pure virtual function 'compose' called - please 
use or define concrete implementations instead."); -} - std::unique_ptr replicate_tensor_to_mesh_mapper(MeshDevice& mesh_device) { return std::make_unique(mesh_device.num_devices()); } @@ -213,12 +198,4 @@ Tensor aggregate_tensor(const Tensor& tensor, const MeshToTensor& composer) { : composer.compose({tensor}); } -Shard2dConfig get_shard2d_config(const std::unordered_map& metadata) { - return Shard2dConfig(std::stoi(metadata.at("row_dim")), std::stoi(metadata.at("col_dim"))); -} - -Concat2dConfig get_concat2d_config(const std::unordered_map& metadata) { - return Concat2dConfig(std::stoi(metadata.at("row_dim")), std::stoi(metadata.at("col_dim"))); -} - } // namespace ttnn::distributed diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp index d7705544e2a..b8fd7e8003e 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.hpp @@ -17,18 +17,20 @@ namespace ttnn::distributed { // Mapper interface that distributes a host tensor onto a multi-device configuration. +// The __attribute__((weak)) instructs pybind imports not to look for a symbol for these functions, as the linker won't +// create one. class TensorToMesh { public: virtual ~TensorToMesh() = default; - virtual std::vector map(const Tensor& tensor) const = 0; - virtual tt::tt_metal::DistributedTensorConfig config() const = 0; + virtual __attribute__((weak)) std::vector map(const Tensor& tensor) const = 0; + virtual __attribute__((weak)) tt::tt_metal::DistributedTensorConfig config() const = 0; }; // Composer interface that aggregates a multi-device tensor into a host tensor. class MeshToTensor { public: virtual ~MeshToTensor() = default; - virtual Tensor compose(const std::vector& tensors) const = 0; + virtual __attribute__((weak)) Tensor compose(const std::vector& tensors) const = 0; }; struct Shard2dConfig { @@ -69,8 +71,4 @@ Tensor distribute_tensor( // Aggregates a multi-device tensor into a host tensor according to the `composer`. 
Tensor aggregate_tensor(const Tensor& tensor, const MeshToTensor& composer); -Shard2dConfig get_shard2d_config(const std::unordered_map& metadata); - -Concat2dConfig get_concat2d_config(const std::unordered_map& metadata); - } // namespace ttnn::distributed diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index dce198ccc88..496f7d8c5f9 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -97,10 +97,10 @@ def manage_config(name, value): from ttnn._ttnn.multi_device import ( MeshDevice, CppMeshToTensor, - CppTensorToMesh, - CppReplicateTensorToMesh, - CppShardTensorToMesh, - CppShardTensorTo2dMesh, + TensorToMesh, + ReplicateTensorToMesh, + ShardTensorToMesh, + ShardTensorTo2dMesh, CppConcatMeshToTensor, CppConcat2dMeshToTensor, ReplicateTensor, @@ -111,8 +111,6 @@ def manage_config(name, value): DistributedTensorConfig, get_device_tensor, get_device_tensors, - get_shard2d_config, - get_concat2d_config, get_distributed_tensor_config, aggregate_as_tensor, replicate_tensor_to_mesh_mapper, diff --git a/ttnn/ttnn/distributed/__init__.py b/ttnn/ttnn/distributed/__init__.py index 4901c6ae8cb..000b23a2b2d 100644 --- a/ttnn/ttnn/distributed/__init__.py +++ b/ttnn/ttnn/distributed/__init__.py @@ -6,10 +6,6 @@ from .distributed import ( MeshDevice, DispatchCoreType, - TensorToMesh, - ShardTensorToMesh, - ShardTensor2dMesh, - ReplicateTensorToMesh, MeshToTensor, ConcatMeshToTensor, ConcatMesh2dToTensor, @@ -28,5 +24,4 @@ ConcatMeshToTensor, synchronize_devices, visualize_mesh_device, - distribute, ) diff --git a/ttnn/ttnn/distributed/distributed.py b/ttnn/ttnn/distributed/distributed.py index fa057bd0051..d7a2734d7f2 100644 --- a/ttnn/ttnn/distributed/distributed.py +++ b/ttnn/ttnn/distributed/distributed.py @@ -210,22 +210,6 @@ def synchronize_devices( # TODO: All of the TensorTo and MeshTo classes will be slowly cut out over the next few days -class TensorToMesh: - """ - Defines the mapping of a torch.Tensor to a device mesh: e.g. Shard/Replicate. - You can also "Bring your own TensorToMesh" based on your custom mapping. - """ - - def __init__(self, mesh_device): - self.mesh_device = mesh_device - - def map(self, tensor: "torch.Tensor"): - raise NotImplementedError("Subclasses must implement this method") - - def config(self): - raise NotImplementedError("Subclasses must implement this method") - - class MeshToTensor: """ Defines the inverse operation of TensorToMesh. Given a set of per-device @@ -238,107 +222,6 @@ def compose(self, tensor: ttnn.Tensor): raise NotImplementedError("Subclasses must implement this method") -class ShardTensorToMesh(TensorToMesh): - def __init__(self, mesh_device, dim): - super().__init__(mesh_device) - self.shard_dim = dim - - def map(self, tensor: "torch.Tensor") -> Dict[int, ttnn.Tensor]: - import torch - - sliced_tensors = torch.chunk(tensor, self.mesh_device.get_num_devices(), dim=self.shard_dim) - return list(sliced_tensors) - - def config(self): - return { - "strategy": "shard", - "shard_dim": f"{self.shard_dim}", - } - - -class ShardTensor2dMesh(TensorToMesh): - """ - Shard a tensor across a 2D mesh of devices. - This class implements a strategy for distributing a tensor across a 2D grid of devices, - allowing for efficient parallel processing in distributed computing environments. - """ - - def __init__(self, mesh_device: MeshDevice, mesh_shape: Tuple[int, int], dims: Tuple[Optional[int], Optional[int]]): - """ - Initialize the ShardTensor2dMesh. - Args: - mesh_device: The target device mesh for distributing the tensor. 
- mesh_shape: The shape of the 2D mesh as (rows, cols). - dims: The dimensions to shard along, specified as (row_dim, col_dim). - The `dims` tuple determines how the tensor is sharded across the 2D mesh: - - row_dim: The dimension to shard across mesh rows (or None for replication). - - col_dim: The dimension to shard across mesh columns (or None for replication). - Examples: - 1. dims=(2, 3) for a tensor of shape (A, B, C, D): - - Shard along dimension 2 (C) across mesh rows - - Shard along dimension 3 (D) across mesh columns - 2. dims=(None, 3): - - Replicate across mesh rows - - Shard along dimension 3 (D) across mesh columns - 3. dims=(None, None): - - Fully replicate the tensor across all devices - """ - super().__init__(mesh_device) - self.mesh_shape: Tuple[int, int] = mesh_shape - self.dims: Tuple[Optional[int], Optional[int]] = dims - - mesh_device_rows, mesh_device_cols = self.mesh_device.shape - if mesh_shape[0] > mesh_device_rows or mesh_shape[1] > mesh_device_cols: - raise ValueError("ShardTensor2dMesh: Device mesh shape does not match the provided mesh shape.") - - def map(self, tensor: "torch.Tensor") -> List["torch.Tensor"]: - """ - Map the input tensor to a list of sharded tensors. - Args: - tensor: The input tensor to be sharded. - Returns: - A list of sharded tensors, one for each device in the mesh. - Raises: - ValueError: If the number of sharding dimensions is not 2. - """ - import torch - - if len(self.dims) != 2: - raise ValueError("ShardTensor2dMesh only supports 2D shard dimensions") - - rows, cols = self.mesh_shape - row_dim, col_dim = self.dims - - # Shard along rows - row_tensors = ( - [tensor.clone() for _ in range(rows)] if row_dim is None else torch.chunk(tensor, rows, dim=row_dim) - ) - - # Shard along columns - if col_dim is None: - return [t.clone() for t in row_tensors for _ in range(cols)] - tensor_shards = [tt for t in row_tensors for tt in torch.chunk(t, cols, dim=col_dim)] - - if len(tensor_shards) != rows * cols: - raise ValueError( - f"ShardTensor2dMesh: Sharding failed. Number of shards should match the product of the mesh dimensions. Got {len(tensor_shards)} shards but expected {rows * cols} ({rows} rows * {cols} cols)." - ) - - return tensor_shards - - def config(self) -> Dict[str, str]: - """ - Provide the configuration of the sharding strategy. - Returns: - A dictionary containing the sharding strategy and dimensions. - """ - return { - "strategy": "shard_2d", - "mesh_shape_y": str(self.mesh_shape[0]), - "mesh_shape_x": str(self.mesh_shape[1]), - } - - class ConcatMesh2dToTensor(MeshToTensor): """ Concatenate tensors from a 2D mesh back into a single tensor. 
@@ -397,20 +280,6 @@ def compose(self, tensor: ttnn.Tensor) -> "torch.Tensor": return torch.cat(row_concatenated, dim=row_dim) -class ReplicateTensorToMesh(TensorToMesh): - def __init__(self, mesh_device: MeshDevice): - super().__init__(mesh_device) - - def map(self, tensor: "torch.Tensor"): - return [tensor for i in range(self.mesh_device.get_num_devices())] - - def config(self): - return { - "strategy": "replicate", - "replication_factor": str(self.mesh_device.get_num_devices()), - } - - class ConcatMeshToTensor(MeshToTensor): def __init__(self, mesh_device: MeshDevice, dim: int): self.concat_dim = dim @@ -425,42 +294,4 @@ def compose(self, tensor: ttnn.Tensor) -> "torch.Tensor": return torch.cat(device_shards_converted_to_torch, dim=self.concat_dim) -@contextlib.contextmanager -def distribute(default: Union[TensorToMesh, MeshToTensor]): - """ - Context manager to temporarily modify the behavior of ttnn.from_torch and ttnn.to_torch to use the specified - mesh_mapper or mesh_composer for tensor distribution and composition to/from MeshDevice. - Invocations of ttnn.from_torch(..) will use the mesh_mapper as defined by the default in ttnn.distribute. - Invocations of ttnn.to_torch(..) will use the mesh_composer as defined by the default in ttnn.distribute. - - Args: - mesh_mapper_or_composer (Union[TensorToMesh, MeshToTensor]): An instance of either TensorToMesh or MeshToTensor - used to map tensors to a mesh or compose tensors from a mesh. - - Example: - with distribute(ShardTensorToMesh(mesh_device, dim=3)): - # Code here will use the default mapper - result = ttnn.from_torch(torch_tensor) - - is equivalent to: - result = ttnn.from_torch(torch_tensor, mesh_mapper=ShardTensorToMesh(mesh_device, dim=3)) - """ - _original_to_torch = ttnn.to_torch - _original_from_torch = ttnn.from_torch - - try: - if isinstance(default, TensorToMesh): - ttnn.from_torch = functools.partial(_original_from_torch, mesh_mapper=default) - elif isinstance(default, MeshToTensor): - ttnn.to_torch = functools.partial(_original_to_torch, mesh_composer=default) - else: - raise ValueError("Argument must be an instance of either TensorToMesh or MeshToTensor.") - yield - - finally: - # Restore the original functions - ttnn.from_torch = _original_from_torch - ttnn.to_torch = _original_to_torch - - __all__ = [] diff --git a/ttnn/ttnn/operations/core.py b/ttnn/ttnn/operations/core.py index 3141b2da59b..47df2cb2719 100644 --- a/ttnn/ttnn/operations/core.py +++ b/ttnn/ttnn/operations/core.py @@ -156,7 +156,7 @@ def from_torch( layout: Optional[ttnn.Layout] = ttnn.ROW_MAJOR_LAYOUT, device: Optional[ttnn.Device] = None, memory_config: Optional[ttnn.MemoryConfig] = None, - mesh_mapper: Optional[Union[ttnn.TensorToMesh, ttnn.CppTensorToMesh]] = None, + mesh_mapper: Optional[ttnn.TensorToMesh] = None, cq_id: Optional[int] = ttnn.DefaultQueueId, ) -> ttnn.Tensor: """ @@ -220,15 +220,8 @@ def from_torch( tilize_input = [] if mesh_mapper: - if isinstance(mesh_mapper, ttnn.CppTensorToMesh): - tensor = ttnn.distribute_tensor(tensor, mesh_mapper, device) - tilize_input = ttnn.to_torch(tensor) - else: - strategy = mesh_mapper.config() - shards = mesh_mapper.map(ttnn.to_torch(tensor)) - tilize_input = shards - if tile is None: - tensor = ttnn.Tensor(tilize_input, dtype, strategy) + tensor = ttnn.distribute_tensor(tensor, mesh_mapper, device) + tilize_input = ttnn.to_torch(tensor) # TODO: find cleaner way of tilizing if tile is not None: @@ -535,7 +528,7 @@ def as_tensor( memory_config: Optional[ttnn.MemoryConfig] = None, cache_file_name: 
Optional[Union[str, pathlib.Path]] = None, preprocess: Optional[Callable[[ttnn.Tensor], ttnn.Tensor]] = None, - mesh_mapper: Union[ttnn.TensorToMesh, ttnn.CppTensorToMesh] = None, + mesh_mapper: Optional[ttnn.TensorToMesh] = None, use_device_tilizer: bool = False, ) -> ttnn.Tensor: """ @@ -587,7 +580,7 @@ def torch_to_ttnn( layout: Optional[ttnn.Layout], device: Optional[ttnn.Device], memory_config: Optional[ttnn.MemoryConfig], - mesh_mapper: Union[ttnn.TensorToMesh, ttnn.CppTensorToMesh], + mesh_mapper: Optional[ttnn.TensorToMesh], ): if preprocess: tensor = preprocess(tensor) @@ -620,7 +613,7 @@ def from_torch_and_dump( dtype: Optional[ttnn.DataType], layout: Optional[ttnn.Layout], cache_file_name: str, - mesh_mapper: Union[ttnn.TensorToMesh, ttnn.CppTensorToMesh], + mesh_mapper: Optional[ttnn.TensorToMesh], ): tensor = torch_to_ttnn(tensor, dtype, layout, device, memory_config, mesh_mapper) logger.debug( @@ -631,9 +624,7 @@ def from_torch_and_dump( ttnn._ttnn.tensor.dump_tensor(cache_file_name, tensor, distributed_config) return tensor - if isinstance(mesh_mapper, ttnn.ReplicateTensorToMesh) or isinstance( - mesh_mapper, ttnn.CppReplicateTensorToMesh - ): + if isinstance(mesh_mapper, ttnn.ReplicateTensorToMesh): storage_type = f"_multi_device" if mesh_mapper else "" elif mesh_mapper: storage_type = f"_multi_device_{device.get_num_devices()}" From f4c994fa718a3edb966db90d4ae82aa40128d4c9 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Wed, 5 Mar 2025 16:46:00 +0000 Subject: [PATCH 56/76] remove replicate distinction --- ttnn/ttnn/operations/core.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ttnn/ttnn/operations/core.py b/ttnn/ttnn/operations/core.py index 47df2cb2719..0de3c995db9 100644 --- a/ttnn/ttnn/operations/core.py +++ b/ttnn/ttnn/operations/core.py @@ -624,9 +624,7 @@ def from_torch_and_dump( ttnn._ttnn.tensor.dump_tensor(cache_file_name, tensor, distributed_config) return tensor - if isinstance(mesh_mapper, ttnn.ReplicateTensorToMesh): - storage_type = f"_multi_device" if mesh_mapper else "" - elif mesh_mapper: + if mesh_mapper: storage_type = f"_multi_device_{device.get_num_devices()}" else: storage_type = "" From 08739b6f5758577bf98451ab702a6080abcf2beb Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Wed, 5 Mar 2025 18:17:43 +0000 Subject: [PATCH 57/76] remove tensortomesh from distributed.py imports --- ttnn/ttnn/distributed/__init__.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/ttnn/ttnn/distributed/__init__.py b/ttnn/ttnn/distributed/__init__.py index 000b23a2b2d..291e8cd0ad7 100644 --- a/ttnn/ttnn/distributed/__init__.py +++ b/ttnn/ttnn/distributed/__init__.py @@ -16,10 +16,6 @@ get_pcie_device_ids, get_device_ids, create_mesh_device, - TensorToMesh, - ShardTensorToMesh, - ShardTensor2dMesh, - ReplicateTensorToMesh, MeshToTensor, ConcatMeshToTensor, synchronize_devices, From a707c0ba66b9ab985a1d3a00eb7fbb23ab197ffc Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Wed, 5 Mar 2025 18:18:07 +0000 Subject: [PATCH 58/76] remove duplicate meshtotensor imports --- ttnn/ttnn/distributed/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/ttnn/ttnn/distributed/__init__.py b/ttnn/ttnn/distributed/__init__.py index 291e8cd0ad7..9306b213785 100644 --- a/ttnn/ttnn/distributed/__init__.py +++ b/ttnn/ttnn/distributed/__init__.py @@ -16,8 +16,6 @@ get_pcie_device_ids, get_device_ids, create_mesh_device, - MeshToTensor, - ConcatMeshToTensor, synchronize_devices, visualize_mesh_device, ) From 4ca99817a69192c2d563e38de5fdca93b3803000 Mon 
Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Wed, 5 Mar 2025 18:43:37 +0000 Subject: [PATCH 59/76] fix syntax error for shard --- tests/ttnn/distributed/test_distributed_tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/ttnn/distributed/test_distributed_tensor.py b/tests/ttnn/distributed/test_distributed_tensor.py index 216ac1364e3..de6dae6d0ff 100644 --- a/tests/ttnn/distributed/test_distributed_tensor.py +++ b/tests/ttnn/distributed/test_distributed_tensor.py @@ -97,7 +97,7 @@ def test_direct_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): use_height_and_width_as_shard_shape=True, ) - mapper = ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) + mapper = ttnn.ShardTensorTo2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) sharded_tensor = ttnn.from_torch( torch_tensor, @@ -173,7 +173,7 @@ def test_direct_concat2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device) use_height_and_width_as_shard_shape=True, ) - mapper = ttnn.ShardTensor2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) + mapper = ttnn.ShardTensorTo2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) sharded_tensor = ttnn.from_torch( torch_tensor, From 317f873fa63c5a235f029a56fa81d9e7c29b82b9 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Wed, 5 Mar 2025 20:01:18 +0000 Subject: [PATCH 60/76] fix test syntax error --- tests/ttnn/distributed/test_distributed_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ttnn/distributed/test_distributed_tensor.py b/tests/ttnn/distributed/test_distributed_tensor.py index de6dae6d0ff..3004cea3ce3 100644 --- a/tests/ttnn/distributed/test_distributed_tensor.py +++ b/tests/ttnn/distributed/test_distributed_tensor.py @@ -108,7 +108,7 @@ def test_direct_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): device=mesh_device, ) - out_pass, out_pcc = comp_pcc(ttnn.to_torch(torch_tensor, sharded_tensor), pcc=0.99) + out_pass, out_pcc = comp_pcc(torch_tensor, ttnn.to_torch(sharded_tensor), pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass From 04dfba698339a068e95b3da8ea7a79a1d6e240c1 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Wed, 5 Mar 2025 20:53:03 +0000 Subject: [PATCH 61/76] improved shape error message --- ttnn/cpp/ttnn/distributed/distributed_tensor.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp index 19646bf8bb1..79d3377d584 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_tensor.cpp @@ -160,7 +160,7 @@ std::unique_ptr shard_tensor_to_2d_mesh_mapper( TT_FATAL( mesh_shape[0] <= mesh_device.shape()[0] && // mesh_shape[1] <= mesh_device.shape()[1], - "Device mesh shape does not match the provided mesh shape."); + "Device mesh shape {} does not match the provided mesh shape ({}, {}).", mesh_device.shape(), mesh_shape[0], mesh_shape[1]); return std::make_unique(mesh_shape[0], mesh_shape[1], config); } From c16ed9198ada1f9127ab8ed1c423783d973ceb8e Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Wed, 5 Mar 2025 21:06:20 +0000 Subject: [PATCH 62/76] syntax fix --- tests/ttnn/distributed/test_distributed_tensor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/ttnn/distributed/test_distributed_tensor.py b/tests/ttnn/distributed/test_distributed_tensor.py index 3004cea3ce3..f7f16e0848c 100644 --- 
a/tests/ttnn/distributed/test_distributed_tensor.py +++ b/tests/ttnn/distributed/test_distributed_tensor.py @@ -349,9 +349,9 @@ def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): shards = ttnn.get_device_tensors(ttnn.distribute_tensor(to_shard, mapper, mesh_device)) - ttnn.aggregate_as_tensor(shards) + sharded_tensor = ttnn.aggregate_as_tensor(shards) - out_pass, out_pcc = comp_pcc(torch_tensor, ttnn.to_torch(shards), pcc=0.99) + out_pass, out_pcc = comp_pcc(torch_tensor, ttnn.to_torch(sharded_tensor), pcc=0.99) logger.info(f"PCC value: {out_pcc}") assert out_pass @@ -361,7 +361,7 @@ def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): ) @pytest.mark.parametrize( "M, K, N", - [pytest.param(32, 128, 64), pytest.param(32, 128, 64)], + [pytest.param(32, 64, 128), pytest.param(32, 128, 64)], ) @pytest.mark.parametrize("dtype", [ttnn.uint16, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b, ttnn.float32]) def test_concat2d_to_tensor(M, K, N, dtype, mesh_shape, mesh_device): From b548526f09be59fd38b0987640958f4f6707b45a Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Wed, 5 Mar 2025 21:11:44 +0000 Subject: [PATCH 63/76] rationalize composer check and method signature --- tests/ttnn/distributed/test_distributed_tensor.py | 2 +- ttnn/cpp/ttnn/distributed/distributed_pybind.cpp | 6 ------ 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/tests/ttnn/distributed/test_distributed_tensor.py b/tests/ttnn/distributed/test_distributed_tensor.py index f7f16e0848c..b37d6c39be4 100644 --- a/tests/ttnn/distributed/test_distributed_tensor.py +++ b/tests/ttnn/distributed/test_distributed_tensor.py @@ -398,7 +398,7 @@ def test_concat2d_to_tensor(M, K, N, dtype, mesh_shape, mesh_device): mapper = ttnn.shard_tensor_to_2d_mesh_mapper(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) - composer = ttnn.concat_2d_mesh_to_tensor_composer(mesh_device, dims=concat_dim, mesh_shape=mesh_shape) + composer = ttnn.concat_2d_mesh_to_tensor_composer(mesh_device, dims=concat_dim) out_tensor = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device), composer) diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 7b9f2c542bb..5d44ac3a514 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -651,18 +651,12 @@ void py_module(py::module& module) { module.def( "concat_2d_mesh_to_tensor_composer", [](MeshDevice& mesh_device, - const std::tuple mesh_shape, const std::tuple dims) -> std::unique_ptr { - TT_FATAL( - std::get<0>(mesh_shape) <= mesh_device.shape()[0] && // - std::get<1>(mesh_shape) <= mesh_device.shape()[1], - "Device mesh shape does not match the provided mesh shape."); return concat_2d_mesh_to_tensor_composer( mesh_device, Concat2dConfig{.row_dim = std::get<0>(dims), .col_dim = std::get<1>(dims)}); }, py::arg("mesh_device"), py::arg("dims"), - py::arg("mesh_shape"), R"doc( Create a Concat2dMeshToTensor composer with the given mesh device and dimensions. 
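With the change above, the 2D composer takes just the mesh device and a (row_dim, col_dim) pair and no longer asks for a mesh shape. A minimal round-trip sketch in the style of the tests touched by this series, assuming an open mesh_device fixture and illustrative shapes (not a definitive usage contract):

    import torch
    import ttnn

    torch_tensor = torch.randn(1, 1, 64, 128)

    # Shard dim 2 across mesh rows and dim 3 across mesh columns of a 2x4 mesh,
    # then concatenate back along the same dims.
    mapper = ttnn.shard_tensor_to_2d_mesh_mapper(mesh_device, mesh_shape=(2, 4), dims=(2, 3))
    composer = ttnn.concat_2d_mesh_to_tensor_composer(mesh_device, dims=(2, 3))

    to_shard = ttnn.from_torch(
        torch_tensor,
        dtype=ttnn.bfloat16,
        layout=ttnn.TILE_LAYOUT,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
        device=mesh_device,
    )

    # Distribute across the mesh, then aggregate the shards into one tensor.
    out_tensor = ttnn.aggregate_tensor(ttnn.distribute_tensor(to_shard, mapper, mesh_device), composer)

    # When the shard and concat dims describe the same layout, this should
    # match torch_tensor up to bfloat16 rounding.
    host_tensor = ttnn.to_torch(out_tensor)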
From 2bc7f46987ab2c0d5ce0879be36f305e67ab50c8 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Wed, 5 Mar 2025 21:13:37 +0000 Subject: [PATCH 64/76] fix composer path --- ttnn/ttnn/operations/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ttnn/ttnn/operations/core.py b/ttnn/ttnn/operations/core.py index 0de3c995db9..259d9abd685 100644 --- a/ttnn/ttnn/operations/core.py +++ b/ttnn/ttnn/operations/core.py @@ -311,7 +311,7 @@ def to_torch( if isinstance(mesh_composer, ttnn.MeshToTensor): return mesh_composer.compose(tensor) else: - return mesh_composer.compose(ttnn.get_device_tensors(tensor)).to_torch() + return ttnn.aggregate_tensor(tensor, mesh_composer).to_torch() if tensor.storage_type() == ttnn.DEVICE_STORAGE_TYPE: raise RuntimeError("ttnn.Tensor cannot be on device when converting to torch.Tensor!") From 0ee4c22530c975ceaeff8289d94ef8f2b4baac83 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Wed, 5 Mar 2025 21:57:45 +0000 Subject: [PATCH 65/76] fix memoryconfig error --- .../distributed/test_distributed_tensor.py | 40 ++----------------- ttnn/ttnn/operations/core.py | 6 +-- 2 files changed, 7 insertions(+), 39 deletions(-) diff --git a/tests/ttnn/distributed/test_distributed_tensor.py b/tests/ttnn/distributed/test_distributed_tensor.py index b37d6c39be4..7ac3a22b677 100644 --- a/tests/ttnn/distributed/test_distributed_tensor.py +++ b/tests/ttnn/distributed/test_distributed_tensor.py @@ -89,21 +89,13 @@ def test_direct_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): K = K // mesh_shape[1] if K < N else K // mesh_shape[0] N = N // mesh_shape[0] if K < N else N // mesh_shape[1] - sharded_mem_config = ttnn.create_sharded_memory_config( - shape=(M // core_grid.y, K // core_grid.x), - core_grid=core_grid, - strategy=ttnn.ShardStrategy.WIDTH, - orientation=ttnn.ShardOrientation.ROW_MAJOR, - use_height_and_width_as_shard_shape=True, - ) - mapper = ttnn.ShardTensorTo2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) sharded_tensor = ttnn.from_torch( torch_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT, - memory_config=sharded_mem_config if M == 32 else ttnn.DRAM_MEMORY_CONFIG, + memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=mapper, device=mesh_device, ) @@ -165,21 +157,13 @@ def test_direct_concat2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device) K = K // mesh_shape[1] if K < N else K // mesh_shape[0] N = N // mesh_shape[0] if K < N else N // mesh_shape[1] - sharded_mem_config = ttnn.create_sharded_memory_config( - shape=(M // core_grid.y, K // core_grid.x), - core_grid=core_grid, - strategy=ttnn.ShardStrategy.WIDTH, - orientation=ttnn.ShardOrientation.ROW_MAJOR, - use_height_and_width_as_shard_shape=True, - ) - mapper = ttnn.ShardTensorTo2dMesh(mesh_device, mesh_shape=mesh_shape, dims=shard_dim) sharded_tensor = ttnn.from_torch( torch_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT, - memory_config=sharded_mem_config if M == 32 else ttnn.DRAM_MEMORY_CONFIG, + memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=mapper, device=mesh_device, ) @@ -329,19 +313,11 @@ def test_shard2d_to_tensor_mesh(M, K, N, dtype, mesh_shape, mesh_device): K = K // mesh_shape[1] if K < N else K // mesh_shape[0] N = N // mesh_shape[0] if K < N else N // mesh_shape[1] - sharded_mem_config = ttnn.create_sharded_memory_config( - shape=(M // core_grid.y, K // core_grid.x), - core_grid=core_grid, - strategy=ttnn.ShardStrategy.WIDTH, - orientation=ttnn.ShardOrientation.ROW_MAJOR, - use_height_and_width_as_shard_shape=True, - ) - to_shard = ttnn.from_torch( 
torch_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT, - memory_config=sharded_mem_config if M == 32 else ttnn.DRAM_MEMORY_CONFIG, + memory_config=ttnn.DRAM_MEMORY_CONFIG, device=mesh_device, ) @@ -380,19 +356,11 @@ def test_concat2d_to_tensor(M, K, N, dtype, mesh_shape, mesh_device): K = K // mesh_shape[1] if K < N else K // mesh_shape[0] N = N // mesh_shape[0] if K < N else N // mesh_shape[1] - sharded_mem_config = ttnn.create_sharded_memory_config( - shape=(M // core_grid.y, K // core_grid.x), - core_grid=core_grid, - strategy=ttnn.ShardStrategy.WIDTH, - orientation=ttnn.ShardOrientation.ROW_MAJOR, - use_height_and_width_as_shard_shape=True, - ) - to_shard = ttnn.from_torch( torch_tensor, dtype=dtype, layout=ttnn.TILE_LAYOUT, - memory_config=sharded_mem_config if M == 32 else ttnn.DRAM_MEMORY_CONFIG, + memory_config=ttnn.DRAM_MEMORY_CONFIG, device=mesh_device, ) diff --git a/ttnn/ttnn/operations/core.py b/ttnn/ttnn/operations/core.py index 259d9abd685..c4c6933078a 100644 --- a/ttnn/ttnn/operations/core.py +++ b/ttnn/ttnn/operations/core.py @@ -208,7 +208,7 @@ def from_torch( if dtype == ttnn.bfloat8_b or dtype == ttnn.bfloat4_b: if layout != ttnn.TILE_LAYOUT: raise RuntimeError("ttnn.from_torch: bfloat8_b/bfloat4_b requires TILE_LAYOUT!") - # Tilize tensor, TODO: this is incredibly non-performant when done on host + # Tilize tensor, TODO: this is non-performant when done on host tensor = ttnn.from_torch(tensor, layout=ttnn.TILE_LAYOUT, tile=tile, pad_value=pad_value, mesh_mapper=None) logical_shape = tensor.shape padded_shape = tensor.padded_shape @@ -231,14 +231,14 @@ def from_torch( if pad_value is not None: tensor = tensor.pad_to_tile(pad_value) if ttnn.is_tensor_storage_on_device(tensor): - # TODO: support tilizing non bfloat/float types on device tensors making this expensive conversion unnecessary + # TODO: support tilizing non bfloat/float types on device tensors making this conversion unnecessary tensor = ttnn.from_device(tensor, cq_id=cq_id) tensor = ttnn.to_layout(tensor, layout, device=device) if device is not None: if memory_config is None: memory_config = ttnn.DRAM_MEMORY_CONFIG - # Handle sharding case which will have already output to a multidevice + # Handle sharding case which would have already output to a multidevice if not ttnn.is_tensor_storage_on_device(tensor): tensor = ttnn.to_device(tensor, device, memory_config=memory_config, cq_id=cq_id) From 954ef6c5d0f154dc32faf7fc78a4822b3c6face4 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Wed, 5 Mar 2025 22:04:16 +0000 Subject: [PATCH 66/76] cleanup --- ttnn/cpp/ttnn/tensor/tensor_ops.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp index 65ddf493197..5001117f885 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp @@ -29,9 +29,6 @@ namespace tt::tt_metal::tensor_ops { Tensor tensor_to_device( const Tensor& input_tensor, IDevice* target_device, const MemoryConfig& mem_config, QueueId cq_id) { - // TODO: remove - std::cout << "debugprint2" << std::endl; - ZoneScoped; GraphTracker::instance().track_function_start("Tensor::to_device", input_tensor, target_device, mem_config); // Tensor can be using borrowed storage. If so, when running in async mode, copy this tensor to owned storage. 
@@ -197,8 +194,6 @@ Tensor tensor_to_layout(const Tensor& input_tensor, Layout target_layout, distri host_storage != nullptr) { distributed_config = host_storage->strategy; } - // TODO: remove, check the tilize tomorrow - Tensor tensor_modified_layout = Tensor(workers.size(), distributed_config); for (int worker_index = 0; worker_index < workers.size(); ++worker_index) { auto& worker = workers[worker_index]; From 0f8d03824b79d8569e4f6e18f224079f753cbdad Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Thu, 6 Mar 2025 15:48:40 +0000 Subject: [PATCH 67/76] add back distributed.py since it has uses --- ttnn/ttnn/distributed/distributed.py | 36 ++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/ttnn/ttnn/distributed/distributed.py b/ttnn/ttnn/distributed/distributed.py index d7a2734d7f2..8fd6ef4d848 100644 --- a/ttnn/ttnn/distributed/distributed.py +++ b/ttnn/ttnn/distributed/distributed.py @@ -293,5 +293,41 @@ def compose(self, tensor: ttnn.Tensor) -> "torch.Tensor": ] return torch.cat(device_shards_converted_to_torch, dim=self.concat_dim) +@contextlib.contextmanager +def distribute(default: Union[ttnn.TensorToMesh, ttnn.MeshToTensor, MeshToTensor]): + """ + Context manager to temporarily modify the behavior of ttnn.from_torch and ttnn.to_torch to use the specified + mesh_mapper or mesh_composer for tensor distribution and composition to/from MeshDevice. + Invocations of ttnn.from_torch(..) will use the mesh_mapper as defined by the default in ttnn.distribute. + Invocations of ttnn.to_torch(..) will use the mesh_composer as defined by the default in ttnn.distribute. + + Args: + mesh_mapper_or_composer (Union[TensorToMesh, MeshToTensor]): An instance of either TensorToMesh or MeshToTensor + used to map tensors to a mesh or compose tensors from a mesh. 
+ + Example: + with distribute(ShardTensorToMesh(mesh_device, dim=3)): + # Code here will use the default mapper + result = ttnn.from_torch(torch_tensor) + + is equivalent to: + result = ttnn.from_torch(torch_tensor, mesh_mapper=ShardTensorToMesh(mesh_device, dim=3)) + """ + _original_to_torch = ttnn.to_torch + _original_from_torch = ttnn.from_torch + + try: + if isinstance(default, ttnn.TensorToMesh) or isinstance(default, ttnn.MeshToTensor): + ttnn.from_torch = functools.partial(_original_from_torch, mesh_mapper=default) + elif isinstance(default, MeshToTensor): + ttnn.to_torch = functools.partial(_original_to_torch, mesh_composer=default) + else: + raise ValueError("Argument must be an instance of either TensorToMesh or MeshToTensor.") + yield + + finally: + # Restore the original functions + ttnn.from_torch = _original_from_torch + ttnn.to_torch = _original_to_torch __all__ = [] From e416e17d47f6b4eef5577db4471d81ec88d8ac6d Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Thu, 6 Mar 2025 16:18:24 +0000 Subject: [PATCH 68/76] change llama_common based tests over --- .../t3000/llama2_70b/tests/test_llama_mlp.py | 4 +- .../demos/t3000/llama2_70b/tt/llama_common.py | 33 - .../tests/test_llama_attention_galaxy.py | 919 ++++++++++-------- .../tests/test_llama_decoder_galaxy.py | 21 +- .../llama3_70b/tt/llama_attention_galaxy.py | 13 +- .../tg/llama3_70b/tt/llama_decoder_galaxy.py | 10 +- .../llama3_70b/tt/llama_embedding_galaxy.py | 4 +- .../tg/llama3_70b/tt/llama_mlp_galaxy.py | 10 +- .../tg/llama3_70b/tt/llama_model_galaxy.py | 21 +- 9 files changed, 528 insertions(+), 507 deletions(-) diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_mlp.py b/models/demos/t3000/llama2_70b/tests/test_llama_mlp.py index fcb0956fb4a..cc9fa417b24 100644 --- a/models/demos/t3000/llama2_70b/tests/test_llama_mlp.py +++ b/models/demos/t3000/llama2_70b/tests/test_llama_mlp.py @@ -6,7 +6,7 @@ from loguru import logger import torch import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.llama2_70b.reference.llama.llama import Llama from models.demos.t3000.llama2_70b.tt.llama_mlp_optimized import TtLlamaMLP_optimized @@ -42,7 +42,7 @@ def tt_llama_mlp_prepare_inputs(llama_mlp_model, x, mode): layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16, device=llama_mlp_model.mesh_device, - mesh_mapper=ReplicateTensorToMesh(llama_mlp_model.mesh_device), + mesh_mapper=replicate_tensor_to_mesh_mapper(llama_mlp_model.mesh_device), ) if mode == "decode": diff --git a/models/demos/t3000/llama2_70b/tt/llama_common.py b/models/demos/t3000/llama2_70b/tt/llama_common.py index 63c8aad8233..a834b18e653 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_common.py +++ b/models/demos/t3000/llama2_70b/tt/llama_common.py @@ -28,42 +28,9 @@ UNIT_TEST_START_POS = 0 UNIT_TEST_GENERATION_LENGTH = 20 from ttnn import ( - TensorToMesh, MeshToTensor, ) - -class ShardTensor2dMesh(TensorToMesh): - def __init__(self, mesh_device, dims, cluster_shape): - super().__init__(mesh_device) - self.dims = dims - self.cluster_shape = cluster_shape - - def map(self, tensor: torch.tensor): - # Returns list of tensors to map to row-major ordering of chips in cluster - tensors_grid_y = None - if self.dims[1] == None: - tensors_grid_y = [tensor.clone() for _ in range(self.cluster_shape[1])] - else: - tensors_grid_y = torch.chunk(tensor, self.cluster_shape[1], dim=self.dims[1]) - - tensors_grid_all = None - if self.dims[0] == None: - tensors_grid_all = 
[t.clone() for t in tensors_grid_y for _ in range(self.cluster_shape[0])] - else: - tensors_grid_all = [ - tt for t in tensors_grid_y for tt in torch.chunk(t, self.cluster_shape[0], dim=self.dims[0]) - ] - - return list(tensors_grid_all) - - def config(self): - return { - "strategy": "shard", - "shard_dim": f"{self.dims[0] if self.dims[0] else self.dims[1]}", - } - - class ConcatMesh2DToTensor(MeshToTensor): def __init__(self, mesh_device, dims, cluster_shape): self.dims = dims diff --git a/models/demos/tg/llama3_70b/tests/test_llama_attention_galaxy.py b/models/demos/tg/llama3_70b/tests/test_llama_attention_galaxy.py index a2ea1b7c792..eba08c3cb59 100644 --- a/models/demos/tg/llama3_70b/tests/test_llama_attention_galaxy.py +++ b/models/demos/tg/llama3_70b/tests/test_llama_attention_galaxy.py @@ -1,472 +1,529 @@ # SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + # SPDX-License-Identifier: Apache-2.0 + import pytest from loguru import logger import torch import ttnn -from ttnn import ReplicateTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper, shard_tensor_to_2d_mesh_mapper import gc + from models.demos.t3000.llama2_70b.reference.llama.llama import Llama from models.demos.tg.llama3_70b.tt.llama_attention_galaxy import TtLlamaAttention_galaxy from models.demos.tg.llama3_70b.tt.llama_common import setup_llama_env from models.demos.t3000.llama2_70b.reference.llama.llama.model import precompute_freqs_cis from models.demos.t3000.llama2_70b.tt.llama_common import ( - check_mesh_device, - extract_pcc_from_log, - generate_rot_emb, - get_rotation_mat, - gather_cos_sin, - precompute_freqs, - MAX_SEQ_LEN, - MAX_SEQ_LEN_LLAMA3, - BASE_URL, - UNIT_TEST_N_LAYER, - UNIT_TEST_LAYER_NUM, - UNIT_TEST_START_POS, - UNIT_TEST_GENERATION_LENGTH, - comp_pcc, - get_rot_transformation_mat, - should_skip_model_load, - check_kv_cache, - num_to_corerange, - ConcatMesh2DToTensor, - ShardTensor2dMesh, + check_mesh_device, + extract_pcc_from_log, + generate_rot_emb, + get_rotation_mat, + gather_cos_sin, + precompute_freqs, + MAX_SEQ_LEN, + MAX_SEQ_LEN_LLAMA3, + BASE_URL, + UNIT_TEST_N_LAYER, + UNIT_TEST_LAYER_NUM, + UNIT_TEST_START_POS, + UNIT_TEST_GENERATION_LENGTH, + comp_pcc, + get_rot_transformation_mat, + should_skip_model_load, + check_kv_cache, + num_to_corerange, + ConcatMesh2DToTensor, ) + from models.utility_functions import skip_for_grayskull + + class PytorchLlamaAttentionModel(torch.nn.Module): - def __init__(self, hf_reference_model, layer_num, rope_theta): - super().__init__() - self.attention = hf_reference_model.layers[layer_num].attention - self.rope_theta = rope_theta - # Disable dropout - self.attention.eval() - - configuration = hf_reference_model.params - self.n_heads = configuration.n_heads - hidden_dim = configuration.dim - self.head_dim = hidden_dim // self.n_heads - self.max_seq_len = configuration.max_seq_len - - def prepare_inputs(self, x, start_pos): - """ - Prepare inputs for decode mode. Assume that current token is at - start_pos, and KV cache has valid data up to start_pos. - """ - batch = x.size(0) - freqs_cis = precompute_freqs_cis(self.head_dim, self.max_seq_len * 2, self.rope_theta) - freqs_cis = freqs_cis[start_pos : start_pos + 1] - - attn_mask = torch.zeros(batch, 1, 1, start_pos + 1) - # attn_mask[:, :, :, : start_pos + 1] = -1e9 - attn_mask = attn_mask.expand(-1, self.n_heads, -1, -1) - - return x, start_pos, freqs_cis, attn_mask - - def prepare_inputs_prefill(self, x, start_pos): - """ - Prepare inputs for decode mode. 
Assume that current token is at - start_pos, and KV cache has valid data up to start_pos. - """ - batch = x.size(0) - seq_len = x.size(1) - freqs_cis = precompute_freqs_cis(self.head_dim, self.max_seq_len * 2, self.rope_theta) - freqs_cis = freqs_cis[start_pos : start_pos + seq_len] - - attn_mask = torch.full((seq_len, seq_len), float("-inf")) - attn_mask = torch.triu(attn_mask, diagonal=1) - attn_mask = attn_mask.expand(batch, self.n_heads, -1, -1) - - return x, start_pos, freqs_cis, attn_mask - - def forward(self, x, start_pos, freqs_cis, mask): - """ - x: (batch, seq, hidden_dim) - start_pos: int - freqs_cis: ? - mask: ? - - return: (batch, seq, hidden_dim) - """ - result = self.attention( - x, - start_pos, - freqs_cis, - mask, - ) - return result + def __init__(self, hf_reference_model, layer_num, rope_theta): + super().__init__() + self.attention = hf_reference_model.layers[layer_num].attention + self.rope_theta = rope_theta + # Disable dropout + self.attention.eval() + + + configuration = hf_reference_model.params + self.n_heads = configuration.n_heads + hidden_dim = configuration.dim + self.head_dim = hidden_dim // self.n_heads + self.max_seq_len = configuration.max_seq_len + + + def prepare_inputs(self, x, start_pos): + """ + Prepare inputs for decode mode. Assume that current token is at + start_pos, and KV cache has valid data up to start_pos. + """ + batch = x.size(0) + freqs_cis = precompute_freqs_cis(self.head_dim, self.max_seq_len * 2, self.rope_theta) + freqs_cis = freqs_cis[start_pos : start_pos + 1] + + + attn_mask = torch.zeros(batch, 1, 1, start_pos + 1) + # attn_mask[:, :, :, : start_pos + 1] = -1e9 + attn_mask = attn_mask.expand(-1, self.n_heads, -1, -1) + + + return x, start_pos, freqs_cis, attn_mask + + + def prepare_inputs_prefill(self, x, start_pos): + """ + Prepare inputs for decode mode. Assume that current token is at + start_pos, and KV cache has valid data up to start_pos. + """ + batch = x.size(0) + seq_len = x.size(1) + freqs_cis = precompute_freqs_cis(self.head_dim, self.max_seq_len * 2, self.rope_theta) + freqs_cis = freqs_cis[start_pos : start_pos + seq_len] + + + attn_mask = torch.full((seq_len, seq_len), float("-inf")) + attn_mask = torch.triu(attn_mask, diagonal=1) + attn_mask = attn_mask.expand(batch, self.n_heads, -1, -1) + + + return x, start_pos, freqs_cis, attn_mask + + + def forward(self, x, start_pos, freqs_cis, mask): + """ + x: (batch, seq, hidden_dim) + start_pos: int + freqs_cis: ? + mask: ? 
+ + + return: (batch, seq, hidden_dim) + """ + result = self.attention( + x, + start_pos, + freqs_cis, + mask, + ) + return result + + def tt_llama_attention_prepare_inputs(llama_attention_model, x, start_pos, rope_theta, mode="decode"): - assert len(x.size()) == 3 - batch, seq_len, _ = x.shape - - cache_name = lambda name: llama_attention_model.cache_path / (f"{name}") - - if mode == "decode": - assert seq_len == 1, "Only supporting decode mode" - x = x.transpose(0, 1).unsqueeze(1) - assert x.shape == (seq_len, 1, batch, llama_attention_model.hidden_size) - - ACT_MEMCFG = ttnn.create_sharded_memory_config( - shape=(x.shape[2], x.shape[3] // 32 // llama_attention_model.cluster_shape[0]), - core_grid=ttnn.CoreGrid(y=4, x=8), - strategy=ttnn.ShardStrategy.WIDTH, - orientation=ttnn.ShardOrientation.ROW_MAJOR, - use_height_and_width_as_shard_shape=True, - ) - xs = ttnn.as_tensor( - x, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - memory_config=ACT_MEMCFG, - device=llama_attention_model.mesh_device, - mesh_mapper=ShardTensor2dMesh( - llama_attention_model.mesh_device, dims=(3, None), cluster_shape=llama_attention_model.cluster_shape - ), - ) - - batch_size_per_group = llama_attention_model.batch_size_per_device_group - - rot_emb = generate_rot_emb(llama_attention_model.head_dim, llama_attention_model.max_seq_len * 2, rope_theta) - rot_mat = get_rotation_mat(rot_emb, start_pos, seq_len, batch=batch_size_per_group) - assert rot_mat.size() == ( - 1, - batch_size_per_group, - llama_attention_model.head_dim, - llama_attention_model.head_dim, - ) - - shard_spec_n_cores_grid = ttnn.CoreRangeSet({num_to_corerange(batch_size_per_group)}) - ROT_MAT_MEMCFG = ttnn.MemoryConfig( - ttnn.TensorMemoryLayout.HEIGHT_SHARDED, - ttnn.BufferType.L1, - ttnn.ShardSpec( - shard_spec_n_cores_grid, - [ - llama_attention_model.head_dim, - llama_attention_model.head_dim, - ], - ttnn.ShardOrientation.ROW_MAJOR, - ), - ) - rot_mats = ttnn.as_tensor( - rot_mat, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - memory_config=ROT_MAT_MEMCFG, - device=llama_attention_model.mesh_device, - mesh_mapper=ReplicateTensorToMesh(llama_attention_model.mesh_device), - ) - - attn_masks = None - - elif mode == "prefill": - assert ( - seq_len % 256 == 0 and seq_len > 0 and seq_len <= 8192 - ), "Prefill mode only supports seqlen as a multiple of 256 up to 8k" - assert batch == 1, "prefill mode only supports batch size 1" - x = x.unsqueeze(0) - assert x.shape == (1, batch, seq_len, llama_attention_model.hidden_size) - xs = ttnn.as_tensor( - x, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - device=llama_attention_model.mesh_device, - mesh_mapper=ShardTensor2dMesh( - llama_attention_model.mesh_device, dims=(3, None), cluster_shape=llama_attention_model.cluster_shape - ), - ) - - cos, sin = precompute_freqs( - llama_attention_model.head_dim, llama_attention_model.max_seq_len * 2, rope_theta, use_scaled=False - ) - cos_gathered, sin_gathered = gather_cos_sin(torch.arange(start_pos, start_pos + seq_len), cos, sin) - assert cos_gathered.size() == (1, 1, seq_len, llama_attention_model.head_dim) - assert sin_gathered.size() == (1, 1, seq_len, llama_attention_model.head_dim) - - cos_gathereds = ttnn.as_tensor( - cos_gathered, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - # cache_file_name=cache_name(f"cos_gathered_prefill_{seq_len}"), - memory_config=ttnn.DRAM_MEMORY_CONFIG, - device=llama_attention_model.mesh_device, - mesh_mapper=ReplicateTensorToMesh(llama_attention_model.mesh_device), - ) 
- sin_gathereds = ttnn.as_tensor( - sin_gathered, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - # cache_file_name=cache_name(f"sin_gathered_prefill_{seq_len}"), - memory_config=ttnn.DRAM_MEMORY_CONFIG, - device=llama_attention_model.mesh_device, - mesh_mapper=ReplicateTensorToMesh(llama_attention_model.mesh_device), - ) - - rot_mats = [cos_gathereds, sin_gathereds] - - attn_mask = torch.full((seq_len, seq_len), torch.finfo(torch.float32).min) - attn_mask = torch.triu(attn_mask, diagonal=1) - attn_mask = attn_mask.expand(1, batch, -1, -1) - attn_masks = ttnn.as_tensor( - attn_mask, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - # cache_file_name=cache_name(f"attn_mask_prefill_{seq_len}"), - mesh_mapper=ReplicateTensorToMesh(llama_attention_model.mesh_device), - memory_config=ttnn.DRAM_MEMORY_CONFIG, - device=llama_attention_model.mesh_device, - ) - - return ( - xs, - start_pos, - rot_mats, - attn_masks, - ) + assert len(x.size()) == 3 + batch, seq_len, _ = x.shape + + + cache_name = lambda name: llama_attention_model.cache_path / (f"{name}") + + + if mode == "decode": + assert seq_len == 1, "Only supporting decode mode" + x = x.transpose(0, 1).unsqueeze(1) + assert x.shape == (seq_len, 1, batch, llama_attention_model.hidden_size) + + + ACT_MEMCFG = ttnn.create_sharded_memory_config( + shape=(x.shape[2], x.shape[3] // 32 // llama_attention_model.cluster_shape[0]), + core_grid=ttnn.CoreGrid(y=4, x=8), + strategy=ttnn.ShardStrategy.WIDTH, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + use_height_and_width_as_shard_shape=True, + ) + xs = ttnn.as_tensor( + x, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ACT_MEMCFG, + device=llama_attention_model.mesh_device, + mesh_mapper=shard_tensor_to_2d_mesh_mapper( + llama_attention_model.mesh_device, mesh_shape=llama_attention_model.cluster_shape, dims=(None, 3) + ), + ) + + + batch_size_per_group = llama_attention_model.batch_size_per_device_group + + + rot_emb = generate_rot_emb(llama_attention_model.head_dim, llama_attention_model.max_seq_len * 2, rope_theta) + rot_mat = get_rotation_mat(rot_emb, start_pos, seq_len, batch=batch_size_per_group) + assert rot_mat.size() == ( + 1, + batch_size_per_group, + llama_attention_model.head_dim, + llama_attention_model.head_dim, + ) + + + shard_spec_n_cores_grid = ttnn.CoreRangeSet({num_to_corerange(batch_size_per_group)}) + ROT_MAT_MEMCFG = ttnn.MemoryConfig( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED, + ttnn.BufferType.L1, + ttnn.ShardSpec( + shard_spec_n_cores_grid, + [ + llama_attention_model.head_dim, + llama_attention_model.head_dim, + ], + ttnn.ShardOrientation.ROW_MAJOR, + ), + ) + rot_mats = ttnn.as_tensor( + rot_mat, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ROT_MAT_MEMCFG, + device=llama_attention_model.mesh_device, + mesh_mapper=replicate_tensor_to_mesh_mapper(llama_attention_model.mesh_device), + ) + + + attn_masks = None + + + elif mode == "prefill": + assert ( + seq_len % 256 == 0 and seq_len > 0 and seq_len <= 8192 + ), "Prefill mode only supports seqlen as a multiple of 256 up to 8k" + assert batch == 1, "prefill mode only supports batch size 1" + x = x.unsqueeze(0) + assert x.shape == (1, batch, seq_len, llama_attention_model.hidden_size) + xs = ttnn.as_tensor( + x, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + device=llama_attention_model.mesh_device, + mesh_mapper=shard_tensor_to_2d_mesh_mapper( + llama_attention_model.mesh_device, mesh_shape=llama_attention_model.cluster_shape, dims=(None, 3) 
+ ), + ) + + + cos, sin = precompute_freqs( + llama_attention_model.head_dim, llama_attention_model.max_seq_len * 2, rope_theta, use_scaled=False + ) + cos_gathered, sin_gathered = gather_cos_sin(torch.arange(start_pos, start_pos + seq_len), cos, sin) + assert cos_gathered.size() == (1, 1, seq_len, llama_attention_model.head_dim) + assert sin_gathered.size() == (1, 1, seq_len, llama_attention_model.head_dim) + + + cos_gathereds = ttnn.as_tensor( + cos_gathered, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + # cache_file_name=cache_name(f"cos_gathered_prefill_{seq_len}"), + memory_config=ttnn.DRAM_MEMORY_CONFIG, + device=llama_attention_model.mesh_device, + mesh_mapper=replicate_tensor_to_mesh_mapper(llama_attention_model.mesh_device), + ) + sin_gathereds = ttnn.as_tensor( + sin_gathered, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + # cache_file_name=cache_name(f"sin_gathered_prefill_{seq_len}"), + memory_config=ttnn.DRAM_MEMORY_CONFIG, + device=llama_attention_model.mesh_device, + mesh_mapper=replicate_tensor_to_mesh_mapper(llama_attention_model.mesh_device), + ) + + + rot_mats = [cos_gathereds, sin_gathereds] + + + attn_mask = torch.full((seq_len, seq_len), torch.finfo(torch.float32).min) + attn_mask = torch.triu(attn_mask, diagonal=1) + attn_mask = attn_mask.expand(1, batch, -1, -1) + attn_masks = ttnn.as_tensor( + attn_mask, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + # cache_file_name=cache_name(f"attn_mask_prefill_{seq_len}"), + mesh_mapper=replicate_tensor_to_mesh_mapper(llama_attention_model.mesh_device), + memory_config=ttnn.DRAM_MEMORY_CONFIG, + device=llama_attention_model.mesh_device, + ) + + + return ( + xs, + start_pos, + rot_mats, + attn_masks, + ) + + def run_test_LlamaAttention_inference( - mesh_device, - cluster_shape, - batch, - seq_len, - pcc, - model_config, - llama_version, - ckpt_dir, - tokenizer_path, - cache_path, + mesh_device, + cluster_shape, + batch, + seq_len, + pcc, + model_config, + llama_version, + ckpt_dir, + tokenizer_path, + cache_path, ): - # Prepare paths and devices - skip_model_load = should_skip_model_load() - - # Prepare configs - hugging_face_reference_model = Llama.build( - ckpt_dir, - tokenizer_path, - max_seq_len=MAX_SEQ_LEN if llama_version == "llama2" else MAX_SEQ_LEN_LLAMA3, - max_batch_size=batch, - n_layers=UNIT_TEST_N_LAYER, - skip_model_load=skip_model_load, - ).model - hugging_face_reference_model.eval() - state_dict = hugging_face_reference_model.state_dict() - logger.info(state_dict.keys()) - torch.manual_seed(0) - configuration = hugging_face_reference_model.params - - # PyTorch model -------------------------------------------------------------------- - pytorch_LlamaAttention_model = PytorchLlamaAttentionModel( - hugging_face_reference_model, UNIT_TEST_LAYER_NUM, configuration.rope_theta - ) - # TT model ------------------------------------------------------------------------- - transformation_mat_torch = get_rot_transformation_mat(32) # 32 for tile size - - transformation_mats = ttnn.as_tensor( - transformation_mat_torch, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), - ) - - tt_LlamaAttention_model = TtLlamaAttention_galaxy( - mesh_device, - cluster_shape, - state_dict, - BASE_URL, - UNIT_TEST_LAYER_NUM, - model_config, - configuration, - transformation_mats, - cache_path=cache_path, - ) - - mode = "decode" if seq_len == 1 else "prefill" - - all_tests_pass, all_pccs = True, [] - if mode == 
"prefill": - generation_start_pos = 0 - generation_length = 1 - else: - generation_start_pos = UNIT_TEST_START_POS - generation_length = UNIT_TEST_GENERATION_LENGTH - - for i in range(generation_length): - # Prepare input - pt_inp_ids = torch.randint(0, configuration.vocab_size, (batch, seq_len)) - pt_inp = hugging_face_reference_model.tok_embeddings(pt_inp_ids) - pt_inp_normed = hugging_face_reference_model.layers[UNIT_TEST_LAYER_NUM].attention_norm(pt_inp) - tt_input = pt_inp_normed.clone() - start_pos = generation_start_pos + i - - # PyTorch output -------------------------------------------------------------------- - if mode == "prefill": - attention_input, start_pos, freqs_cis, attn_mask = pytorch_LlamaAttention_model.prepare_inputs_prefill( - pt_inp_normed, start_pos - ) - else: - attention_input, start_pos, freqs_cis, attn_mask = pytorch_LlamaAttention_model.prepare_inputs( - pt_inp_normed, start_pos - ) - - pytorch_out = pytorch_LlamaAttention_model( - attention_input, - start_pos, - freqs_cis, - attn_mask, - ) - - # TT hardware execution ------------------------------------------------------------- - attention_input, start_pos, rot_mat, attn_mask = tt_llama_attention_prepare_inputs( - tt_LlamaAttention_model, tt_input, start_pos, configuration.rope_theta, mode=mode - ) - tt_out = tt_LlamaAttention_model( - attention_input, - rot_mat, - start_pos, - attn_mask, - mode=mode, - ) - # tt_out = [ttnn.to_torch(shard) for shard in ttnn.get_device_tensors(tt_out.cpu())] - - tt_out = ttnn.to_torch( - tt_out, mesh_composer=ConcatMesh2DToTensor(mesh_device, dims=(3, 1), cluster_shape=cluster_shape) - ) - tt_out = tt_out[:, 0:1, :, :] - tt_out = tt_out.permute(2, 1, 0, 3).squeeze(1) # [seq, batch, hidden_dim] - - does_pass, output_pcc = comp_pcc(pytorch_out, tt_out, pcc) - logger.info(f"Output: {output_pcc}") - - all_pccs.append(extract_pcc_from_log(output_pcc)) - - if does_pass: - logger.info(f"[start_pos={start_pos}] {llama_version} Attention output Passed!") - else: - logger.warning( - f"[start_pos={start_pos}] {llama_version} Attention output Failed! PCC value is lower than {pcc}" - ) - all_tests_pass = False - - logger.info(f"Average PCC over {len(all_pccs)} tokens: {sum(all_pccs) / len(all_pccs)}") - - # Check kv cache - # PyTorch output -------------------------------------------------------------------- - pytorch_layer_present = [ - pytorch_LlamaAttention_model.attention.cache_k.clone().permute(0, 2, 1, 3)[ - :batch, ... - ], # [batch, n_kv_heads, seq, head_dim] - pytorch_LlamaAttention_model.attention.cache_v.clone().permute(0, 2, 1, 3)[ - :batch, ... - ], # [batch, n_kv_heads, seq, head_dim] - ] - # TT hardware output ---------------------------------------------------------------- - - # concat the pasts by heads - tt_layer_present_all = [ttnn.from_device(lp) for lp in tt_LlamaAttention_model.layer_past] - - tt_layer_present_all = [ - ttnn.to_torch(lp, mesh_composer=ConcatMesh2DToTensor(mesh_device, dims=(0, 1), cluster_shape=cluster_shape))[ - :batch, ... - ] - for lp in tt_layer_present_all - ] - - cache_test_pass = check_kv_cache( - pytorch_layer_present, - tt_layer_present_all, - generation_start_pos, - generation_length, - seq_len, - mode == "prefill", - pcc, - ) - - all_tests_pass = all_tests_pass and cache_test_pass - - if all_tests_pass: - logger.info(f"{llama_version} Attention output Passed!") - else: - gc.collect() - logger.warning(f"{llama_version} Attention output Failed!") - assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" 
+ # Prepare paths and devices + skip_model_load = should_skip_model_load() + + + # Prepare configs + hugging_face_reference_model = Llama.build( + ckpt_dir, + tokenizer_path, + max_seq_len=MAX_SEQ_LEN if llama_version == "llama2" else MAX_SEQ_LEN_LLAMA3, + max_batch_size=batch, + n_layers=UNIT_TEST_N_LAYER, + skip_model_load=skip_model_load, + ).model + hugging_face_reference_model.eval() + state_dict = hugging_face_reference_model.state_dict() + logger.info(state_dict.keys()) + torch.manual_seed(0) + configuration = hugging_face_reference_model.params + + + # PyTorch model -------------------------------------------------------------------- + pytorch_LlamaAttention_model = PytorchLlamaAttentionModel( + hugging_face_reference_model, UNIT_TEST_LAYER_NUM, configuration.rope_theta + ) + # TT model ------------------------------------------------------------------------- + transformation_mat_torch = get_rot_transformation_mat(32) # 32 for tile size + + + transformation_mats = ttnn.as_tensor( + transformation_mat_torch, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + device=mesh_device, + mesh_mapper=replicate_tensor_to_mesh_mapper(mesh_device), + ) + + + tt_LlamaAttention_model = TtLlamaAttention_galaxy( + mesh_device, + cluster_shape, + state_dict, + BASE_URL, + UNIT_TEST_LAYER_NUM, + model_config, + configuration, + transformation_mats, + cache_path=cache_path, + ) + + + mode = "decode" if seq_len == 1 else "prefill" + + + all_tests_pass, all_pccs = True, [] + if mode == "prefill": + generation_start_pos = 0 + generation_length = 1 + else: + generation_start_pos = UNIT_TEST_START_POS + generation_length = UNIT_TEST_GENERATION_LENGTH + + + for i in range(generation_length): + # Prepare input + pt_inp_ids = torch.randint(0, configuration.vocab_size, (batch, seq_len)) + pt_inp = hugging_face_reference_model.tok_embeddings(pt_inp_ids) + pt_inp_normed = hugging_face_reference_model.layers[UNIT_TEST_LAYER_NUM].attention_norm(pt_inp) + tt_input = pt_inp_normed.clone() + start_pos = generation_start_pos + i + + + # PyTorch output -------------------------------------------------------------------- + if mode == "prefill": + attention_input, start_pos, freqs_cis, attn_mask = pytorch_LlamaAttention_model.prepare_inputs_prefill( + pt_inp_normed, start_pos + ) + else: + attention_input, start_pos, freqs_cis, attn_mask = pytorch_LlamaAttention_model.prepare_inputs( + pt_inp_normed, start_pos + ) + + + pytorch_out = pytorch_LlamaAttention_model( + attention_input, + start_pos, + freqs_cis, + attn_mask, + ) + + + # TT hardware execution ------------------------------------------------------------- + attention_input, start_pos, rot_mat, attn_mask = tt_llama_attention_prepare_inputs( + tt_LlamaAttention_model, tt_input, start_pos, configuration.rope_theta, mode=mode + ) + tt_out = tt_LlamaAttention_model( + attention_input, + rot_mat, + start_pos, + attn_mask, + mode=mode, + ) + # tt_out = [ttnn.to_torch(shard) for shard in ttnn.get_device_tensors(tt_out.cpu())] + + + tt_out = ttnn.to_torch( + tt_out, mesh_composer=ConcatMesh2DToTensor(mesh_device, dims=(3, 1), cluster_shape=cluster_shape) + ) + tt_out = tt_out[:, 0:1, :, :] + tt_out = tt_out.permute(2, 1, 0, 3).squeeze(1) # [seq, batch, hidden_dim] + + + does_pass, output_pcc = comp_pcc(pytorch_out, tt_out, pcc) + logger.info(f"Output: {output_pcc}") + + + all_pccs.append(extract_pcc_from_log(output_pcc)) + + + if does_pass: + logger.info(f"[start_pos={start_pos}] {llama_version} Attention output Passed!") + else: 
+ logger.warning( + f"[start_pos={start_pos}] {llama_version} Attention output Failed! PCC value is lower than {pcc}" + ) + all_tests_pass = False + + + logger.info(f"Average PCC over {len(all_pccs)} tokens: {sum(all_pccs) / len(all_pccs)}") + + + # Check kv cache + # PyTorch output -------------------------------------------------------------------- + pytorch_layer_present = [ + pytorch_LlamaAttention_model.attention.cache_k.clone().permute(0, 2, 1, 3)[ + :batch, ... + ], # [batch, n_kv_heads, seq, head_dim] + pytorch_LlamaAttention_model.attention.cache_v.clone().permute(0, 2, 1, 3)[ + :batch, ... + ], # [batch, n_kv_heads, seq, head_dim] + ] + # TT hardware output ---------------------------------------------------------------- + + + # concat the pasts by heads + tt_layer_present_all = [ttnn.from_device(lp) for lp in tt_LlamaAttention_model.layer_past] + + + tt_layer_present_all = [ + ttnn.to_torch(lp, mesh_composer=ConcatMesh2DToTensor(mesh_device, dims=(0, 1), cluster_shape=cluster_shape))[ + :batch, ... + ] + for lp in tt_layer_present_all + ] + + + cache_test_pass = check_kv_cache( + pytorch_layer_present, + tt_layer_present_all, + generation_start_pos, + generation_length, + seq_len, + mode == "prefill", + pcc, + ) + + + all_tests_pass = all_tests_pass and cache_test_pass + + + if all_tests_pass: + logger.info(f"{llama_version} Attention output Passed!") + else: + gc.collect() + logger.warning(f"{llama_version} Attention output Failed!") + assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" + + @skip_for_grayskull("Requires eth connected devices to run") @pytest.mark.parametrize( - "cluster_shape, mesh_device", [pytest.param((4, 8), (8, 4), id="4x8_grid")], indirect=["mesh_device"] + "cluster_shape, mesh_device", [pytest.param((4, 8), (8, 4), id="4x8_grid")], indirect=["mesh_device"] ) @pytest.mark.parametrize( - "llama_version", - (("llama3-tg"),), + "llama_version", + (("llama3-tg"),), ) @pytest.mark.parametrize( - "batch, seq_len, pcc", - [ - (32, 1, 0.9995), - (1, 256, 0.999), - ], - ids=[ - "decode", - "prefill", - ], + "batch, seq_len, pcc", + [ + (32, 1, 0.9995), + (1, 256, 0.999), + ], + ids=[ + "decode", + "prefill", + ], ) @pytest.mark.parametrize( - "max_batch_size, max_context_len", - ( - (32, 2048), - # (16, 8192), - ), - ids=( - "short_context", - # "long_context", - ), + "max_batch_size, max_context_len", + ( + (32, 2048), + # (16, 8192), + ), + ids=( + "short_context", + # "long_context", + ), ) def test_LlamaAttention_inference( - batch, - seq_len, - pcc, - mesh_device, - max_batch_size, - max_context_len, - llama_version, - cluster_shape, - use_program_cache, + batch, + seq_len, + pcc, + mesh_device, + max_batch_size, + max_context_len, + llama_version, + cluster_shape, + use_program_cache, ): - if batch > max_batch_size: - pytest.skip(f"Decode with {batch} users is not supported with large context") - - if batch == 1 and seq_len > max_context_len: - pytest.skip(f"Prefill with {seq_len=} is not supported with short context") - - if llama_version == "llama2" and seq_len > 2048: - pytest.skip(f"Llama2 with {seq_len=} is not supported (max 2048)") - - model_config, ckpt_dir, tokenizer_path, cache_path = setup_llama_env( - llama_version=llama_version, - max_batch_size=max_batch_size, - max_context_len=max_context_len, - ) - check_mesh_device(mesh_device, model_config) - run_test_LlamaAttention_inference( - mesh_device, - cluster_shape, - batch, - seq_len, - pcc, - model_config, - llama_version, - ckpt_dir, - tokenizer_path, 
- cache_path, - ) + if batch > max_batch_size: + pytest.skip(f"Decode with {batch} users is not supported with large context") + + + if batch == 1 and seq_len > max_context_len: + pytest.skip(f"Prefill with {seq_len=} is not supported with short context") + + + if llama_version == "llama2" and seq_len > 2048: + pytest.skip(f"Llama2 with {seq_len=} is not supported (max 2048)") + + + model_config, ckpt_dir, tokenizer_path, cache_path = setup_llama_env( + llama_version=llama_version, + max_batch_size=max_batch_size, + max_context_len=max_context_len, + ) + check_mesh_device(mesh_device, model_config) + run_test_LlamaAttention_inference( + mesh_device, + cluster_shape, + batch, + seq_len, + pcc, + model_config, + llama_version, + ckpt_dir, + tokenizer_path, + cache_path, + ) diff --git a/models/demos/tg/llama3_70b/tests/test_llama_decoder_galaxy.py b/models/demos/tg/llama3_70b/tests/test_llama_decoder_galaxy.py index 1c48eb04d89..29dddca1698 100644 --- a/models/demos/tg/llama3_70b/tests/test_llama_decoder_galaxy.py +++ b/models/demos/tg/llama3_70b/tests/test_llama_decoder_galaxy.py @@ -6,7 +6,7 @@ from loguru import logger import torch import ttnn -from ttnn import ReplicateTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper, shard_tensor_to_2d_mesh_mapper from models.demos.t3000.llama2_70b.reference.llama.llama import Llama from models.demos.tg.llama3_70b.tt.llama_decoder_galaxy import TtLlamaDecoder_galaxy @@ -33,7 +33,6 @@ check_kv_cache, num_to_corerange, ConcatMesh2DToTensor, - ShardTensor2dMesh, ) import gc @@ -129,8 +128,8 @@ def tt_llama_decoder_prepare_inputs(llama_decoder_model, x, start_pos, mode): layout=ttnn.TILE_LAYOUT, device=llama_decoder_model.mesh_device, memory_config=ACT_MEMCFG, - mesh_mapper=ShardTensor2dMesh( - llama_decoder_model.mesh_device, dims=(3, None), cluster_shape=llama_decoder_model.cluster_shape + mesh_mapper=shard_tensor_to_2d_mesh_mapper( + llama_decoder_model.mesh_device, mesh_shape=llama_decoder_model.cluster_shape, dims=(None, 3) ), ) @@ -159,7 +158,7 @@ def tt_llama_decoder_prepare_inputs(llama_decoder_model, x, start_pos, mode): layout=ttnn.TILE_LAYOUT, memory_config=ROT_MAT_MEMCFG, device=llama_decoder_model.mesh_device, - mesh_mapper=ReplicateTensorToMesh(llama_decoder_model.mesh_device), + mesh_mapper=replicate_tensor_to_mesh_mapper(llama_decoder_model.mesh_device), ) attn_masks = None @@ -173,8 +172,8 @@ def tt_llama_decoder_prepare_inputs(llama_decoder_model, x, start_pos, mode): layout=ttnn.TILE_LAYOUT, device=llama_decoder_model.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensor2dMesh( - llama_decoder_model.mesh_device, dims=(3, None), cluster_shape=llama_decoder_model.cluster_shape + mesh_mapper=shard_tensor_to_2d_mesh_mapper( + llama_decoder_model.mesh_device, cluster_shape=llama_decoder_model.cluster_shape, dims=(None, 3) ), ) @@ -196,7 +195,7 @@ def tt_llama_decoder_prepare_inputs(llama_decoder_model, x, start_pos, mode): # cache_file_name=cache_name(f"cos_gathered_prefill_{seq_len}"), memory_config=ttnn.DRAM_MEMORY_CONFIG, device=llama_decoder_model.mesh_device, - mesh_mapper=ReplicateTensorToMesh(llama_decoder_model.mesh_device), + mesh_mapper=replicate_tensor_to_mesh_mapper(llama_decoder_model.mesh_device), ) sin_gathereds = ttnn.as_tensor( sin_gathered, @@ -205,7 +204,7 @@ def tt_llama_decoder_prepare_inputs(llama_decoder_model, x, start_pos, mode): # cache_file_name=cache_name(f"sin_gathered_prefill_{seq_len}"), memory_config=ttnn.DRAM_MEMORY_CONFIG, device=llama_decoder_model.mesh_device, - 
mesh_mapper=ReplicateTensorToMesh(llama_decoder_model.mesh_device), + mesh_mapper=replicate_tensor_to_mesh_mapper(llama_decoder_model.mesh_device), ) rot_mats = [cos_gathereds, sin_gathereds] @@ -218,7 +217,7 @@ def tt_llama_decoder_prepare_inputs(llama_decoder_model, x, start_pos, mode): dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, # cache_file_name=cache_name(f"attn_mask_prefill_{seq_len}"), - mesh_mapper=ReplicateTensorToMesh(llama_decoder_model.mesh_device), + mesh_mapper=replicate_tensor_to_mesh_mapper(llama_decoder_model.mesh_device), memory_config=ttnn.DRAM_MEMORY_CONFIG, device=llama_decoder_model.mesh_device, ) @@ -273,7 +272,7 @@ def run_test_LlamaDecoder_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=replicate_tensor_to_mesh_mapper(mesh_device), ) tt_LlamaDecoder_model = TtLlamaDecoder_galaxy( diff --git a/models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py b/models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py index 62abd01a8ec..6346cd604eb 100644 --- a/models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py +++ b/models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py @@ -6,9 +6,8 @@ import math import torch import ttnn -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh +from ttnn import shard_tensor_to_mesh_mapper, shard_tensor_to_2d_mesh_mapper, replicate_tensor_to_mesh_mapper from models.demos.t3000.llama2_70b.tt.llama_common import ( - ShardTensor2dMesh, ConcatMesh2DToTensor, ) from models.demos.t3000.llama2_70b.tt.llama_common import ( @@ -91,7 +90,7 @@ def get_slice_mat(self): dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, device=self.mesh_device, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=1), + mesh_mapper=shard_tensor_to_mesh_mapper(self.mesh_device, dim=1), ) def get_user_selection_mat(self): @@ -104,7 +103,7 @@ def get_user_selection_mat(self): dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, device=self.mesh_device, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=replicate_tensor_to_mesh_mapper(self.mesh_device), ) def init_kv_cache(self): @@ -133,7 +132,7 @@ def init_kv_cache(self): ttnn.as_tensor( lp, device=self.mesh_device, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=replicate_tensor_to_mesh_mapper(self.mesh_device), layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=ttnn.bfloat8_b, @@ -206,7 +205,7 @@ def load_weights(self): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensor2dMesh(self.mesh_device, dims=(2, 3), cluster_shape=self.cluster_shape), + mesh_mapper=shard_tensor_to_2d_mesh_mapper(self.mesh_device, cluster_shape=self.cluster_shape, dims=(3, 2)), cache_file_name=self.cache_path / wqkv_cache_str, ) @@ -216,7 +215,7 @@ def load_weights(self): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensor2dMesh(self.mesh_device, dims=(3, 2), cluster_shape=self.cluster_shape), + mesh_mapper=shard_tensor_to_2d_mesh_mapper(self.mesh_device, cluster_shape=self.cluster_shape, dims=(2, 3)), cache_file_name=self.cache_path / wo_cache_str, ) diff --git a/models/demos/tg/llama3_70b/tt/llama_decoder_galaxy.py b/models/demos/tg/llama3_70b/tt/llama_decoder_galaxy.py index 5c6e1c64ef2..5d3fae48334 100644 --- a/models/demos/tg/llama3_70b/tt/llama_decoder_galaxy.py +++ b/models/demos/tg/llama3_70b/tt/llama_decoder_galaxy.py @@ -7,13 +7,11 @@ from 
models.demos.tg.llama3_70b.tt.llama_attention_galaxy import TtLlamaAttention_galaxy from models.demos.tg.llama3_70b.tt.llama_mlp_galaxy import TtLlamaMLP_galaxy -from models.demos.t3000.llama2_70b.tt.llama_common import ( - ShardTensor2dMesh, -) from models.demos.tg.llama3_70b.tt.llama_common import ( tt_sharded_distributed_rmsnorm, tt_distributed_rmsnorm, ) +from ttnn import shard_tensor_to_2d_mesh_mapper class TtLlamaDecoder_galaxy: @@ -112,7 +110,8 @@ def load_weights(self): layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensor2dMesh(self.mesh_device, (2, None), self.cluster_shape), + mesh_mapper=shard_tensor_to_2d_mesh_mapper +(self.mesh_device, self.cluster_shape, (None, 2)), cache_file_name=self.cache_path / attn_norm_sharded_str, ) @@ -122,7 +121,8 @@ def load_weights(self): layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensor2dMesh(self.mesh_device, (2, None), self.cluster_shape), + mesh_mapper=shard_tensor_to_2d_mesh_mapper +(self.mesh_device, self.cluster_shape, (None, 2)), cache_file_name=self.cache_path / ffn_norm_sharded_str, ) diff --git a/models/demos/tg/llama3_70b/tt/llama_embedding_galaxy.py b/models/demos/tg/llama3_70b/tt/llama_embedding_galaxy.py index d76abe350f2..46d49eee42f 100644 --- a/models/demos/tg/llama3_70b/tt/llama_embedding_galaxy.py +++ b/models/demos/tg/llama3_70b/tt/llama_embedding_galaxy.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import ttnn -from models.demos.t3000.llama2_70b.tt.llama_common import ShardTensor2dMesh +from ttnn import shard_tensor_to_2d_mesh_mapper class TtLlamaEmbedding_galaxy: @@ -28,7 +28,7 @@ def __init__( layout=ttnn.ROW_MAJOR_LAYOUT, device=mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensor2dMesh(self.mesh_device, dims=(3, None), cluster_shape=self.cluster_shape), + mesh_mapper=shard_tensor_to_2d_mesh_mapper(self.mesh_device, mesh_shape=self.cluster_shape, dims=(None, 3)), cache_file_name=cache_path / embedding_cache_name, ) diff --git a/models/demos/tg/llama3_70b/tt/llama_mlp_galaxy.py b/models/demos/tg/llama3_70b/tt/llama_mlp_galaxy.py index c876713ce9f..de8c1b3d11d 100644 --- a/models/demos/tg/llama3_70b/tt/llama_mlp_galaxy.py +++ b/models/demos/tg/llama3_70b/tt/llama_mlp_galaxy.py @@ -4,13 +4,13 @@ from typing import List import ttnn -from models.demos.t3000.llama2_70b.tt.llama_common import ShardTensor2dMesh, ConcatMesh2DToTensor +from models.demos.t3000.llama2_70b.tt.llama_common import ConcatMesh2DToTensor from models.utility_functions import nearest_32 from models.demos.tg.llama3_70b.tt.llama_common import tt_all_reduce, tt_composite_sharded_all_reduce from models.demos.t3000.falcon40b.tt.model_utils import ( matmul_2d_config_from_tensor_shapes as get_matmul_2d_config_from_tensor_shapes, ) - +from ttnn import shard_tensor_to_2d_mesh_mapper class TtLlamaMLP_galaxy: def __init__( @@ -79,7 +79,7 @@ def load_weights(self): device=self.mesh_device, # memory_config=self.w1_mem_config, # TODO: Reenable when DRAM-SHARDED PCC issues resolves memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensor2dMesh(self.mesh_device, dims=(2, 3), cluster_shape=self.cluster_shape), + mesh_mapper=shard_tensor_to_2d_mesh_mapper(self.mesh_device, mesh_shape=self.cluster_shape, dims=(3, 2)), cache_file_name=self.cache_path / w1_cache_str, ) @@ -90,7 +90,7 @@ def load_weights(self): device=self.mesh_device, # memory_config=self.mlp_config["W1_MEM_CONFIG"](self.mesh_device, 
self.cluster_shape), # TODO: Reenable when DRAM-SHARDED PCC issues resolves memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensor2dMesh(self.mesh_device, dims=(2, 3), cluster_shape=self.cluster_shape), + mesh_mapper=shard_tensor_to_2d_mesh_mapper(self.mesh_device, mesh_shape=self.cluster_shape, dims=(3, 2)), cache_file_name=self.cache_path / w3_cache_str, ) @@ -101,7 +101,7 @@ def load_weights(self): device=self.mesh_device, # memory_config=self.mlp_config["W2_MEM_CONFIG"](self.mesh_device), # TODO: Reenable when DRAM-SHARDED PCC issues resolves memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensor2dMesh(self.mesh_device, dims=(3, 2), cluster_shape=self.cluster_shape), + mesh_mapper=shard_tensor_to_2d_mesh_mapper(self.mesh_device, mesh_shape=self.cluster_shape, dims=(2, 3)), cache_file_name=self.cache_path / w2_cache_str, ) diff --git a/models/demos/tg/llama3_70b/tt/llama_model_galaxy.py b/models/demos/tg/llama3_70b/tt/llama_model_galaxy.py index b309872779b..429673467cf 100644 --- a/models/demos/tg/llama3_70b/tt/llama_model_galaxy.py +++ b/models/demos/tg/llama3_70b/tt/llama_model_galaxy.py @@ -7,7 +7,7 @@ from tqdm import tqdm import torch import ttnn -from ttnn import ReplicateTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper, shard_tensor_to_2d_mesh_mapper from models.demos.tg.llama3_70b.tt.llama_decoder_galaxy import TtLlamaDecoder_galaxy from models.demos.tg.llama3_70b.tt.llama_embedding_galaxy import TtLlamaEmbedding_galaxy from models.demos.t3000.llama2_70b.tt.llama_common import ( @@ -17,7 +17,6 @@ get_rot_transformation_mat, num_to_corerange, gather_cos_sin, - ShardTensor2dMesh, ) from models.demos.tg.llama3_70b.tt.llama_common import ( tt_all_reduce, @@ -25,7 +24,7 @@ tt_sharded_distributed_rmsnorm, tt_distributed_rmsnorm, ) - +from ttnn import shard_tensor_to_2d_mesh_mapper def is_power_of_two(n): if n <= 0: @@ -74,7 +73,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=replicate_tensor_to_mesh_mapper(mesh_device), ) logger.info("Creating Layers") @@ -142,7 +141,7 @@ def load_weights(self): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensor2dMesh(self.mesh_device, dims=(2, 3), cluster_shape=self.cluster_shape), + mesh_mapper=shard_tensor_to_2d_mesh_mapper(self.mesh_device, mesh_shape=self.cluster_shape, dims=(3, 2)), cache_file_name=self.cache_path / lm_head_cache_str, ) @@ -152,7 +151,7 @@ def load_weights(self): layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensor2dMesh(self.mesh_device, (2, None), self.cluster_shape), + mesh_mapper=shard_tensor_to_2d_mesh_mapper(self.mesh_device, mesh_shape=self.cluster_shape, dims=(None, 2)), cache_file_name=self.cache_path / norm_sharded_cache_str, ) @@ -173,7 +172,7 @@ def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None, attn_mask=None, layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=replicate_tensor_to_mesh_mapper(self.mesh_device), ) xs = self.tt_embd(x) @@ -226,7 +225,7 @@ def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None, attn_mask=None, device=self.mesh_device, cache_file_name=cache_name(f"rot_mat_decode_galaxy_{start_pos}"), memory_config=ROT_MAT_MEMCFG, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + 
mesh_mapper=replicate_tensor_to_mesh_mapper(self.mesh_device), ) attn_masks = None @@ -247,7 +246,7 @@ def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None, attn_mask=None, # cache_file_name=cache_name(f"cos_gathered_prefill_galaxy_{start_pos}"), device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=replicate_tensor_to_mesh_mapper(self.mesh_device), ) sin_gathereds = ttnn.as_tensor( sin_gathered, @@ -256,7 +255,7 @@ def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None, attn_mask=None, # cache_file_name=cache_name(f"sin_gathered_prefill_galaxy_{start_pos}"), device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=replicate_tensor_to_mesh_mapper(self.mesh_device), ) rot_mats = [cos_gathereds, sin_gathereds] @@ -269,7 +268,7 @@ def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None, attn_mask=None, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, # cache_file_name=cache_name(f"attn_mask_prefill_{seq_len}"), - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=replicate_tensor_to_mesh_mapper(self.mesh_device), memory_config=ttnn.DRAM_MEMORY_CONFIG, device=self.mesh_device, ) From 6820794bbfcdf1821f65c5ca925d8bbdbd9e0e77 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Thu, 6 Mar 2025 16:31:47 +0000 Subject: [PATCH 69/76] switch replicate --- models/common/rmsnorm.py | 2 +- models/common/tests/test_rmsnorm.py | 4 +- .../demos/falcon7b_common/tests/test_utils.py | 6 +-- .../falcon7b_common/tt/falcon_attention.py | 8 ++-- .../falcon7b_common/tt/falcon_causallm.py | 4 +- models/demos/falcon7b_common/tt/falcon_mlp.py | 8 ++-- .../demos/falcon7b_common/tt/falcon_model.py | 8 ++-- .../demos/falcon7b_common/tt/model_utils.py | 4 +- .../multimodal/test_llama_class_embedding.py | 4 +- .../multimodal/test_llama_conv2d_patch.py | 2 +- .../multimodal/test_llama_cross_attention.py | 8 ++-- ..._llama_cross_attention_transformer_text.py | 12 +++--- .../multimodal/test_llama_cross_block.py | 10 ++--- .../multimodal/test_llama_image_attention.py | 2 +- .../multimodal/test_llama_image_block.py | 2 +- .../tests/multimodal/test_llama_image_mlp.py | 2 +- .../test_llama_image_transformer.py | 2 +- .../tests/multimodal/test_llama_layernorm.py | 2 +- .../test_llama_positional_embedding.py | 4 +- .../test_llama_tile_position_embedding.py | 4 +- .../tests/test_llama_attention_prefill.py | 4 +- .../tests/test_llama_decoder_prefill.py | 4 +- .../llama3/tests/test_llama_embedding.py | 2 +- models/demos/llama3/tests/test_llama_mlp.py | 2 +- .../llama3/tests/test_llama_model_prefill.py | 2 +- models/demos/llama3/tt/llama_attention.py | 4 +- models/demos/llama3/tt/llama_common.py | 12 +++--- models/demos/llama3/tt/llama_model.py | 8 ++-- models/demos/llama3/tt/llama_rope.py | 12 +++--- .../tt/multimodal/llama_class_embedding.py | 4 +- .../tt/multimodal/llama_conv2d_patch.py | 8 ++-- ...lama_cross_attention_transformer_vision.py | 2 +- .../llama3/tt/multimodal/llama_image_block.py | 4 +- .../llama3/tt/multimodal/llama_image_mlp.py | 2 +- .../llama3/tt/multimodal/llama_layernorm.py | 4 +- .../multimodal/llama_positional_embedding.py | 8 ++-- .../llama_tile_position_embedding.py | 6 +-- .../tt/multimodal/llama_vision_encoder.py | 4 +- .../tt/multimodal/llama_vision_model.py | 22 +++++------ models/demos/qwen/demo/demo.py | 12 +++--- models/demos/qwen/tests/test_lm_head.py | 2 +- .../demos/qwen/tests/test_qwen_attention.py | 2 +- 
models/demos/qwen/tests/test_qwen_decoder.py | 2 +- .../demos/qwen/tests/test_qwen_embedding.py | 2 +- models/demos/qwen/tests/test_qwen_mlp.py | 2 +- models/demos/qwen/tests/test_qwen_model.py | 2 +- models/demos/qwen/tests/test_qwen_perf.py | 4 +- models/demos/qwen/tt/model_config.py | 4 +- models/demos/qwen/tt/qwen_common.py | 8 ++-- .../falcon40b/tests/test_falcon_attention.py | 8 ++-- .../t3000/falcon40b/tests/test_falcon_mlp.py | 2 +- .../t3000/falcon40b/tt/falcon_attention.py | 10 ++--- .../t3000/falcon40b/tt/falcon_decoder.py | 10 ++--- models/demos/t3000/falcon40b/tt/falcon_mlp.py | 4 +- .../demos/t3000/falcon40b/tt/falcon_model.py | 12 +++--- .../demos/t3000/falcon40b/tt/model_utils.py | 2 +- ...emo_continuous_batching_paged_attention.py | 4 +- .../tests/test_chunked_generation.py | 2 +- .../llama2_70b/tests/test_llama_attention.py | 20 +++++----- .../llama2_70b/tests/test_llama_decoder.py | 10 ++--- .../llama2_70b/tests/test_llama_generation.py | 2 +- .../llama2_70b/tests/test_llama_model.py | 2 +- .../llama2_70b/tt/llama_decoder_optimized.py | 6 +-- .../t3000/llama2_70b/tt/llama_generation.py | 2 +- .../llama2_70b/tt/llama_model_optimized.py | 20 +++++----- .../demos/t3000/llama2_70b/tt/llama_rope.py | 10 ++--- models/demos/t3000/mixtral8x7b/demo/demo.py | 2 +- .../mixtral8x7b/demo/demo_with_prefill.py | 4 +- .../tests/test_mixtral_attention.py | 2 +- .../tests/test_mixtral_attention_prefill.py | 4 +- .../mixtral8x7b/tests/test_mixtral_decoder.py | 2 +- .../tests/test_mixtral_decoder_prefill.py | 4 +- .../mixtral8x7b/tests/test_mixtral_mlp.py | 4 +- .../tests/test_mixtral_mlp_prefill.py | 4 +- .../mixtral8x7b/tests/test_mixtral_model.py | 2 +- .../tests/test_mixtral_model_prefill.py | 4 +- .../mixtral8x7b/tests/test_mixtral_moe.py | 4 +- .../tests/test_mixtral_moe_prefill.py | 4 +- .../mixtral8x7b/tests/test_mixtral_perf.py | 4 +- .../tests/test_mixtral_perplexity.py | 2 +- .../tests/test_mixtral_rms_norm.py | 4 +- .../t3000/mixtral8x7b/tt/mixtral_attention.py | 4 +- .../t3000/mixtral8x7b/tt/mixtral_common.py | 26 ++++++------- .../t3000/mixtral8x7b/tt/mixtral_model.py | 6 +-- .../demos/t3000/mixtral8x7b/tt/mixtral_moe.py | 8 ++-- .../tests/multi_chip/test_falcon_attention.py | 4 +- .../tests/multi_chip/test_falcon_causallm.py | 8 ++-- .../tests/multi_chip/test_falcon_decoder.py | 4 +- .../tests/multi_chip/test_falcon_mlp.py | 4 +- .../tests/multi_chip/test_falcon_model.py | 4 +- models/demos/ttnn_falcon7b/tt/falcon_model.py | 2 +- .../ttnn_resnet/tests/resnet50_test_infra.py | 2 +- models/demos/wormhole/bert_tiny/demo/demo.py | 4 +- .../bert_tiny/tests/test_performance.py | 4 +- models/demos/wormhole/distilbert/demo/demo.py | 8 ++-- .../distilbert/tests/test_perf_distilbert.py | 4 +- .../tests/test_unet_bottleneck.py | 2 +- .../tests/test_unet_downblock.py | 2 +- .../tests/test_unet_multi_device.py | 2 +- .../functional_unet/tests/test_unet_trace.py | 4 +- .../tests/test_unet_upblock.py | 2 +- models/experimental/grok/demo/demo.py | 6 +-- .../grok/tests/test_grok_decoder.py | 2 +- .../experimental/grok/tests/test_grok_mlp.py | 4 +- .../experimental/grok/tests/test_grok_moe.py | 4 +- .../grok/tests/test_grok_rms_norm.py | 6 +-- models/experimental/grok/tt/grok_attention.py | 4 +- models/experimental/grok/tt/grok_common.py | 12 +++--- models/experimental/grok/tt/grok_moe.py | 10 ++--- models/experimental/grok/tt/grok_rms_norm.py | 6 +-- tech_reports/CNNs/cnn_optimizations.md | 2 +- tech_reports/LLMs/llms.md | 2 +- .../Programming_Mesh_of_Devices_with_TT-NN.md | 4 +- 
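Before the per-file hunks, a rough sketch of the call-pattern change this series applies. It is illustrative only: it assumes a ttnn build that exposes the functional helpers replicate_tensor_to_mesh_mapper and shard_tensor_to_2d_mesh_mapper, an already-opened mesh_device, and placeholder tensor shapes and cluster shape.
import torch
import ttnn


def mapper_migration_example(mesh_device, cluster_shape=(4, 8)):
    # Previously: mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device)
    replicated = ttnn.from_torch(
        torch.randn(1, 1, 32, 32),
        dtype=ttnn.bfloat16,
        layout=ttnn.TILE_LAYOUT,
        device=mesh_device,
        mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device),
    )
    # Previously: mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device, dims=(2, 3), cluster_shape=cluster_shape)
    # The functional helper takes the mesh shape explicitly, and the call sites
    # in this series also reverse the dims ordering, e.g. (2, 3) -> (3, 2).
    sharded = ttnn.from_torch(
        torch.randn(1, 1, 256, 512),
        dtype=ttnn.bfloat16,
        layout=ttnn.TILE_LAYOUT,
        device=mesh_device,
        mesh_mapper=ttnn.shard_tensor_to_2d_mesh_mapper(
            mesh_device, mesh_shape=cluster_shape, dims=(3, 2)
        ),
    )
    return replicated, sharded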
.../distributed/test_data_parallel_example.py | 2 +- .../test_data_parallel_example_TG.py | 2 +- tests/ttnn/distributed/test_multidevice_TG.py | 38 +++++++++---------- .../test_tensor_parallel_example_T3000.py | 2 +- .../bert_tiny/test_bert_tiny_wh.py | 12 +++--- .../distilbert/test_ttnn_distilbert_wh.py | 4 +- .../operations/prefetcher_common.py | 4 +- .../unit_tests/operations/test_new_conv2d.py | 2 +- .../tensor/test_tensor_prealloc_and_write.py | 2 +- tests/ttnn/unit_tests/test_multi_device.py | 34 ++++++++--------- .../unit_tests/test_multi_device_async.py | 18 ++++----- .../unit_tests/test_multi_device_events.py | 2 +- .../unit_tests/test_multi_device_trace.py | 4 +- .../unit_tests/test_multi_device_trace_TG.py | 4 +- .../unit_tests/test_multi_device_trace_tgg.py | 4 +- tests/ttnn/unit_tests/test_reshape.py | 2 +- 129 files changed, 365 insertions(+), 365 deletions(-) diff --git a/models/common/rmsnorm.py b/models/common/rmsnorm.py index 28eb9cadf55..6926df48f7a 100644 --- a/models/common/rmsnorm.py +++ b/models/common/rmsnorm.py @@ -80,7 +80,7 @@ def __init__( layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=weight_memory_config, cache_file_name=cache_name, - mesh_mapper=ttnn.ReplicateTensorToMesh(device) if is_mesh_device else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(device) if is_mesh_device else None, ) if self.is_distributed: diff --git a/models/common/tests/test_rmsnorm.py b/models/common/tests/test_rmsnorm.py index 1828a6702e4..1933b0798ab 100644 --- a/models/common/tests/test_rmsnorm.py +++ b/models/common/tests/test_rmsnorm.py @@ -12,7 +12,7 @@ os.environ["WH_ARCH_YAML"] = "wormhole_b0_80_arch_eth_dispatch.yaml" import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.common.rmsnorm import RMSNorm as TtRMSNorm from models.utility_functions import ( @@ -130,7 +130,7 @@ def test_rmsnorm_multidevice(t3k_mesh_device, is_sharded, use_program_cache, res device=t3k_mesh_device, dtype=dtype, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + mesh_mapper=replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) tt_output = tt_model(tt_input) diff --git a/models/demos/falcon7b_common/tests/test_utils.py b/models/demos/falcon7b_common/tests/test_utils.py index 3e7e29fe478..076d64500e6 100644 --- a/models/demos/falcon7b_common/tests/test_utils.py +++ b/models/demos/falcon7b_common/tests/test_utils.py @@ -4,7 +4,7 @@ import torch import ttnn -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh +from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper from transformers import FalconForCausalLM from models.utility_functions import tt_tensors_to_torch_tensors @@ -20,14 +20,14 @@ def initialize_kv_cache(configuration, num_layers, batch_size, max_seq_len, mesh dtype=ttnn.bfloat16, device=mesh_device, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_v_cache = tt_from_torch( v_cache, dtype=ttnn.bfloat16, device=mesh_device, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) kv_cache += ((tt_k_cache, tt_v_cache),) return kv_cache diff --git a/models/demos/falcon7b_common/tt/falcon_attention.py b/models/demos/falcon7b_common/tt/falcon_attention.py index ea1c2740148..54af7c56102 100644 --- a/models/demos/falcon7b_common/tt/falcon_attention.py +++ 
b/models/demos/falcon7b_common/tt/falcon_attention.py @@ -9,7 +9,7 @@ from models.demos.falcon7b_common.tt.model_utils import get_falcon_default_core_grid import ttnn -from ttnn import ReplicateTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper from models.utility_functions import ( nearest_32, @@ -155,7 +155,7 @@ def __init__( dtype=model_config["DEFAULT_DTYPE"], device=mesh_device, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) # optimized version can utilize single float value for softmax @@ -175,7 +175,7 @@ def __init__( device=self.mesh_device, layout=ttnn.TILE_LAYOUT, memory_config=self.model_config["ATTN_OPTIMIZED_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) self.model_config["ATTN_OUTPUT_TENSORS"][seq_len] = tt_tensors @@ -553,7 +553,7 @@ def __init__( dtype=model_config["DEFAULT_DTYPE"], device=mesh_device, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) def forward( diff --git a/models/demos/falcon7b_common/tt/falcon_causallm.py b/models/demos/falcon7b_common/tt/falcon_causallm.py index 09702d4a94f..6194690c3d0 100644 --- a/models/demos/falcon7b_common/tt/falcon_causallm.py +++ b/models/demos/falcon7b_common/tt/falcon_causallm.py @@ -6,7 +6,7 @@ import torch import ttnn -from ttnn import ReplicateTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper from models.demos.falcon7b_common.tt.falcon_lm_head import falcon_lm_head_matmul_2d from models.demos.falcon7b_common.tt.falcon_model import TtFalconModelShared from models.demos.falcon7b_common.tt.model_utils import ( @@ -123,7 +123,7 @@ def __init__( device=self.mesh_device, layout=ttnn.TILE_LAYOUT, memory_config=self.model_config["LM_HEAD_MM_INPUT_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) self.lm_head_weights = get_weights_cached( diff --git a/models/demos/falcon7b_common/tt/falcon_mlp.py b/models/demos/falcon7b_common/tt/falcon_mlp.py index 7694e2d4ea8..d6884d7b59e 100644 --- a/models/demos/falcon7b_common/tt/falcon_mlp.py +++ b/models/demos/falcon7b_common/tt/falcon_mlp.py @@ -4,7 +4,7 @@ import torch import ttnn -from ttnn import ReplicateTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper from models.demos.falcon7b_common.tt.model_utils import ( get_falcon_default_core_grid, get_weights_cached, @@ -176,7 +176,7 @@ def _load_mlp_padded_tensors(self): device=self.mesh_device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) mlp_padding_tensors[seq_len] = tt_padding self.model_config["MLP_PREFILL_PADDING_TENSORS"] = mlp_padding_tensors @@ -191,7 +191,7 @@ def _allocate_output_mlp_tensors(self): device=self.mesh_device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) self.model_config["MLP_OUTPUT_TENSORS"] = out_tt @@ -344,7 +344,7 @@ def _load_mlp_padded_tensors(self): device=self.mesh_device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + 
mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) self.model_config["MLP_DECODE_PADDING_TENSORS"] = tt_paddings diff --git a/models/demos/falcon7b_common/tt/falcon_model.py b/models/demos/falcon7b_common/tt/falcon_model.py index fa2932cb0c8..b8c9a50423b 100644 --- a/models/demos/falcon7b_common/tt/falcon_model.py +++ b/models/demos/falcon7b_common/tt/falcon_model.py @@ -7,7 +7,7 @@ import torch import ttnn -from ttnn import ReplicateTensorToMesh, ShardTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper, ShardTensorToMesh from models.demos.falcon7b_common.tt.falcon_decoder import TtFalconDecoderLayer from models.demos.falcon7b_common.tt.model_utils import get_weights_cached, layernorm @@ -134,7 +134,7 @@ def model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token device=self.mesh_device, layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=self.model_config["ATTN_MASK_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) for attention_mask_slice in attention_mask_ ] @@ -156,7 +156,7 @@ def model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token device=self.mesh_device, layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=self.model_config["ATTN_MASK_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) # Repeat attn masks for all heads tt_attention_mask = ttnn.repeat( @@ -210,7 +210,7 @@ def model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token device=self.mesh_device, layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=self.model_config["ATTN_MASK_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) if not self.model_config["l1_sharded"]: # Tilize attn masks diff --git a/models/demos/falcon7b_common/tt/model_utils.py b/models/demos/falcon7b_common/tt/model_utils.py index b7ce657bd69..3bf7dc0919d 100644 --- a/models/demos/falcon7b_common/tt/model_utils.py +++ b/models/demos/falcon7b_common/tt/model_utils.py @@ -4,7 +4,7 @@ import torch import ttnn -from ttnn import ReplicateTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper from models.utility_functions import is_wormhole_b0 @@ -50,7 +50,7 @@ def preprocess_weights(weights_to_cache): layout=tt_layout, device=mesh_device, memory_config=model_config[f"{weight_config_str}_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(mesh_device) if type(mesh_device) == ttnn.MeshDevice else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device) if type(mesh_device) == ttnn.MeshDevice else None, cache_file_name=str(path), preprocess=preprocess_weights, ) diff --git a/models/demos/llama3/tests/multimodal/test_llama_class_embedding.py b/models/demos/llama3/tests/multimodal/test_llama_class_embedding.py index dc395842338..7f97f631ef6 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_class_embedding.py +++ b/models/demos/llama3/tests/multimodal/test_llama_class_embedding.py @@ -15,7 +15,7 @@ ##### TTNN imports ##### import ttnn from ttnn import experimental as ttl -from ttnn import ConcatMeshToTensor, ReplicateTensorToMesh +from ttnn import ConcatMeshToTensor, replicate_tensor_to_mesh_mapper from models.utility_functions import skip_for_grayskull from models.utility_functions import ( comp_pcc, @@ -108,7 +108,7 @@ def test_llama_class_embedding_inference( layout=layout, device=mesh_device,
memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) logger.info(f"TT Input tensor shape: {tt_input_tensor.shape}") diff --git a/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py b/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py index c38dd5ccb26..520c2b30cff 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py +++ b/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py @@ -14,7 +14,7 @@ ##### TTNN imports ##### import ttnn from ttnn import experimental as ttl -from ttnn import ConcatMeshToTensor, ReplicateTensorToMesh +from ttnn import ConcatMeshToTensor, replicate_tensor_to_mesh_mapper from models.utility_functions import skip_for_grayskull from models.utility_functions import ( comp_pcc, diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py index 15490b6ba41..d4eedaaf744 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py @@ -170,7 +170,7 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_full_text_mask = ttnn.from_torch( full_text_mask_expand[b : b + 1], @@ -178,7 +178,7 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_out = tt_model( tt_tensor_x, @@ -209,7 +209,7 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_xattn_mask = ttnn.reshape( tt_xattn_mask, @@ -224,7 +224,7 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_full_text_mask = ttnn.reshape( tt_full_text_mask, diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py index 7c59a9630de..16b85b2b220 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py @@ -193,7 +193,7 @@ def test_llama_cross_attention_transformer_text_inference( dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_full_text_mask_expand_1NSH = ttnn.from_torch( full_text_mask_expand_1NSH[b : b + 1], @@ -201,7 +201,7 @@ def test_llama_cross_attention_transformer_text_inference( dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT,
memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_full_text_mask_expand_11SD = ttnn.from_torch( full_text_mask_expand_11SD[b : b + 1], @@ -253,7 +253,7 @@ def test_llama_cross_attention_transformer_text_inference( dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) rot_mats, _ = get_single_rot_mat( @@ -275,7 +275,7 @@ def test_llama_cross_attention_transformer_text_inference( dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_xattn_mask = ttnn.reshape( tt_xattn_mask, @@ -290,7 +290,7 @@ def test_llama_cross_attention_transformer_text_inference( dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_full_text_mask_expand_1NSH = ttnn.reshape( tt_full_text_mask_expand_1NSH, @@ -309,7 +309,7 @@ def test_llama_cross_attention_transformer_text_inference( device=mesh_device, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_full_text_mask_expand_11SD = ttnn.to_layout(tt_full_text_mask_expand_11SD, ttnn.TILE_LAYOUT) diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py index 7516354af66..0d132fb5191 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py @@ -161,7 +161,7 @@ def test_llama_cross_attention_transformer_block_inference( dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_full_text_mask_expand_1NSH = ttnn.from_torch( full_text_mask_expand_1NSH[b : b + 1], @@ -169,7 +169,7 @@ def test_llama_cross_attention_transformer_block_inference( dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_full_text_mask_expand_11SD = ttnn.from_torch( full_text_mask_expand_11SD[b : b + 1], @@ -207,7 +207,7 @@ def test_llama_cross_attention_transformer_block_inference( dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_xattn_mask = ttnn.reshape( tt_xattn_mask, @@ -222,7 +222,7 @@ def test_llama_cross_attention_transformer_block_inference( dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_full_text_mask_expand_1NSH = ttnn.reshape( tt_full_text_mask_expand_1NSH, @@ -241,7 +241,7 @@ def test_llama_cross_attention_transformer_block_inference( device=mesh_device, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, - 
mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_full_text_mask_expand_11SD = ttnn.to_layout(tt_full_text_mask_expand_11SD, ttnn.TILE_LAYOUT) diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_attention.py b/models/demos/llama3/tests/multimodal/test_llama_image_attention.py index 3d9e6977145..03be0a437a9 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_attention.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_attention.py @@ -96,7 +96,7 @@ def test_llama_attention_inference(batch, num_chunks, mesh_device, use_program_c dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_out = tt_model(attention_input, mask=tt_mask) diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_block.py b/models/demos/llama3/tests/multimodal/test_llama_image_block.py index 23096202e29..f21ad59bda2 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_block.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_block.py @@ -104,7 +104,7 @@ def test_llama_block_inference(batch, num_chunks, mesh_device, gated, use_progra dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_out = tt_model(attention_input, mask=tt_mask) diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_mlp.py b/models/demos/llama3/tests/multimodal/test_llama_image_mlp.py index c6b65ef7f9d..8013df2f2da 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_mlp.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_mlp.py @@ -75,7 +75,7 @@ def test_llama_mlp_inference(batch, num_chunks, mesh_device, use_program_cache, tt_input = ttnn.from_torch( torch_input, device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), dtype=ttnn.bfloat16, memory_config=ttnn.DRAM_MEMORY_CONFIG, layout=ttnn.TILE_LAYOUT, diff --git a/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py b/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py index 502736ac790..5a22ffb3b84 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py +++ b/models/demos/llama3/tests/multimodal/test_llama_image_transformer.py @@ -135,7 +135,7 @@ def test_llama_image_transformer_inference( dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) with torch.no_grad(): diff --git a/models/demos/llama3/tests/multimodal/test_llama_layernorm.py b/models/demos/llama3/tests/multimodal/test_llama_layernorm.py index d52d9f415f3..56daf2540c0 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_layernorm.py +++ b/models/demos/llama3/tests/multimodal/test_llama_layernorm.py @@ -74,7 +74,7 @@ def test_layernorm_inference(mesh_device, use_program_cache, reset_seeds, ensure tt_input = ttnn.from_torch( torch_input, device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), dtype=ttnn.bfloat16, memory_config=ttnn.DRAM_MEMORY_CONFIG, 
layout=ttnn.TILE_LAYOUT, diff --git a/models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py b/models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py index c5262bf2235..cad1136dd1c 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py +++ b/models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py @@ -17,7 +17,7 @@ ##### TTNN imports ##### import ttnn from ttnn import experimental as ttl -from ttnn import ConcatMeshToTensor, ReplicateTensorToMesh +from ttnn import ConcatMeshToTensor, replicate_tensor_to_mesh_mapper from models.utility_functions import skip_for_grayskull from models.utility_functions import ( comp_pcc, @@ -128,7 +128,7 @@ def test_llama_positional_embedding_inference( layout=layout, device=mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_input_tensor = ttnn.to_layout(tt_input_tensor, ttnn.ROW_MAJOR_LAYOUT) diff --git a/models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py b/models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py index 4ba64dd76ff..97517e1178b 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py +++ b/models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py @@ -17,7 +17,7 @@ ##### TTNN imports ##### import ttnn from ttnn import experimental as ttl -from ttnn import ConcatMeshToTensor, ReplicateTensorToMesh +from ttnn import ConcatMeshToTensor, replicate_tensor_to_mesh_mapper from models.utility_functions import skip_for_grayskull from models.utility_functions import ( comp_pcc, @@ -98,7 +98,7 @@ def test_llama_conv2d_inference( layout=layout, device=mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) logger.info(f"TT Input tensor shape: {tt_input_tensor.shape}") diff --git a/models/demos/llama3/tests/test_llama_attention_prefill.py b/models/demos/llama3/tests/test_llama_attention_prefill.py index 52d6e2cc19a..534ffa2c407 100644 --- a/models/demos/llama3/tests/test_llama_attention_prefill.py +++ b/models/demos/llama3/tests/test_llama_attention_prefill.py @@ -100,7 +100,7 @@ def test_llama_attention_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) transformation_mats = {"prefill": transformation_mats_prefill} @@ -129,7 +129,7 @@ def test_llama_attention_inference( device=mesh_device, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_model = TtLlamaAttention( diff --git a/models/demos/llama3/tests/test_llama_decoder_prefill.py b/models/demos/llama3/tests/test_llama_decoder_prefill.py index a370011383d..2e0c9551054 100644 --- a/models/demos/llama3/tests/test_llama_decoder_prefill.py +++ b/models/demos/llama3/tests/test_llama_decoder_prefill.py @@ -102,7 +102,7 @@ def test_llama_decoder_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) transformation_mats = {"prefill": transformation_mats_prefill} @@ -127,7 +127,7 @@ def
test_llama_decoder_inference( device=mesh_device, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) # Initialize TT model diff --git a/models/demos/llama3/tests/test_llama_embedding.py b/models/demos/llama3/tests/test_llama_embedding.py index 71d56a3a7f4..2b28a51944b 100644 --- a/models/demos/llama3/tests/test_llama_embedding.py +++ b/models/demos/llama3/tests/test_llama_embedding.py @@ -67,7 +67,7 @@ def test_llama_embedding(max_seq_len, batch_size, mesh_device, use_program_cache tt_input = ttnn.from_torch( pt_input.squeeze(1), device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT, ) diff --git a/models/demos/llama3/tests/test_llama_mlp.py b/models/demos/llama3/tests/test_llama_mlp.py index 710ee9498c5..37770024ce1 100644 --- a/models/demos/llama3/tests/test_llama_mlp.py +++ b/models/demos/llama3/tests/test_llama_mlp.py @@ -75,7 +75,7 @@ def test_llama_mlp_inference(seq_len, batch_size, mesh_device, use_program_cache device=mesh_device, mesh_mapper=ttnn.ShardTensor2dMesh( mesh_device, dims=(None, 3) if model_args.is_galaxy else (None, None), mesh_shape=model_args.cluster_shape - ), # When both dims are None, the mapper used is `ReplicateTensorToMesh` + ), # When both dims are None, the mapper used is `ttnn.replicate_tensor_to_mesh_mapper` dtype=ttnn.bfloat8_b, memory_config=( ( diff --git a/models/demos/llama3/tests/test_llama_model_prefill.py b/models/demos/llama3/tests/test_llama_model_prefill.py index 667764a2304..6e6bfcca2e3 100644 --- a/models/demos/llama3/tests/test_llama_model_prefill.py +++ b/models/demos/llama3/tests/test_llama_model_prefill.py @@ -160,7 +160,7 @@ def test_llama_model_inference( device=mesh_device, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) # Load TTNN model diff --git a/models/demos/llama3/tt/llama_attention.py b/models/demos/llama3/tt/llama_attention.py index d1c1bee93b0..9afd6c45738 100644 --- a/models/demos/llama3/tt/llama_attention.py +++ b/models/demos/llama3/tt/llama_attention.py @@ -74,7 +74,7 @@ def __init__( dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, device=self.mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) self.dtype = dtype @@ -277,7 +277,7 @@ def init_kv_cache(self, configuration, weight_cache_path): layout=self.model_config["ATTN_W_LAYOUT_TILE"], device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), cache_file_name=( f"{weight_cache_path}/kvcache_{k_or_v.shape}" if weight_cache_path and not configuration.dummy_weights diff --git a/models/demos/llama3/tt/llama_common.py b/models/demos/llama3/tt/llama_common.py index dd6873ed8b3..7ec888fa9b3 100644 --- a/models/demos/llama3/tt/llama_common.py +++ b/models/demos/llama3/tt/llama_common.py @@ -219,14 +219,14 @@ def get_prefill_rot_mat(head_dim, mesh_device, seq_len, theta, scale_factor, ori dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) sin_gathereds = 
ttnn.from_torch( sin_gathered, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) rot_mats = [cos_gathereds, sin_gathereds] @@ -280,13 +280,13 @@ def get_single_rot_mat( device=mesh_device if not on_host else None, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device) if num_devices > 1 or not on_host else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device) if num_devices > 1 or not on_host else None, ), ttnn.from_torch( rot_matrix.unsqueeze(0).unsqueeze(0), # 1,1,head_dim,head_dim device=mesh_device if not on_host else None, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device) if num_devices > 1 or not on_host else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device) if num_devices > 1 or not on_host else None, ) @@ -402,7 +402,7 @@ def sample_host(tt_input, mesh_device, temperature=0.6, top_p=0.08, on_host=True layout=ttnn.ROW_MAJOR_LAYOUT, dtype=ttnn.uint32, device=None, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device) if mesh_device.get_num_devices() > 1 else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device) if mesh_device.get_num_devices() > 1 else None, ), pt_out, ) @@ -413,7 +413,7 @@ def sample_host(tt_input, mesh_device, temperature=0.6, top_p=0.08, on_host=True layout=ttnn.ROW_MAJOR_LAYOUT, dtype=ttnn.uint32, device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ), pt_out, ) diff --git a/models/demos/llama3/tt/llama_model.py b/models/demos/llama3/tt/llama_model.py index 8f49cd04299..1f473a473dc 100644 --- a/models/demos/llama3/tt/llama_model.py +++ b/models/demos/llama3/tt/llama_model.py @@ -113,7 +113,7 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag device=self.mesh_device, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) tokens_embd = self.embd(tokens) tokens_embd = ttnn.unsqueeze_to_4D(tokens_embd) @@ -127,7 +127,7 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag device=self.mesh_device, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) else: tt_page_table = None @@ -138,7 +138,7 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag device=self.mesh_device, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) else: tt_chunk_page_table = None @@ -172,7 +172,7 @@ def prepare_decode_inputs_host(self, tokens, current_pos, page_table=None): tokens, device=None, dtype=ttnn.uint32, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) tokens = ttnn.unsqueeze_to_4D(tokens) diff --git a/models/demos/llama3/tt/llama_rope.py b/models/demos/llama3/tt/llama_rope.py index 533768df5b5..3a4a414ca5f 100644 --- a/models/demos/llama3/tt/llama_rope.py +++ b/models/demos/llama3/tt/llama_rope.py @@ -4,7 +4,7 @@ import torch import ttnn -from ttnn import 
ReplicateTensorToMesh, ShardTensor2dMesh +from ttnn import replicate_tensor_to_mesh_mapper, ShardTensor2dMesh from models.common.lightweightmodule import LightweightModule from models.demos.llama3.tt.llama_common import precompute_freqs, get_rot_transformation_mat, gather_cos_sin from models.utility_functions import nearest_32 @@ -56,14 +56,14 @@ def __init__( device=device, layout=ttnn.TILE_LAYOUT, dtype=datatype, - mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(device) if self.is_mesh_device else None, ) self.sin_matrix = ttnn.from_torch( sin_matrix, device=device, layout=ttnn.TILE_LAYOUT, dtype=datatype, - mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(device) if self.is_mesh_device else None, ) batch_grid = ttnn.num_cores_to_corerangeset(batch_size, self.core_grid, row_wise=True) @@ -107,7 +107,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, dtype=datatype, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(device) if self.is_mesh_device else None, ) def get_both_trans_mats(self): @@ -133,7 +133,7 @@ def get_rot_idxs(self, position_idxs, on_host=False): position_idxs, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(self.device) if self.is_mesh_device else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.device) if self.is_mesh_device else None, ) else: # On device rot_idxs = ttnn.as_tensor( @@ -142,7 +142,7 @@ def get_rot_idxs(self, position_idxs, on_host=False): layout=ttnn.ROW_MAJOR_LAYOUT, device=self.device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(self.device) if self.is_mesh_device else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.device) if self.is_mesh_device else None, ) return rot_idxs diff --git a/models/demos/llama3/tt/multimodal/llama_class_embedding.py b/models/demos/llama3/tt/multimodal/llama_class_embedding.py index 6bb57822953..fd3d8defe4c 100644 --- a/models/demos/llama3/tt/multimodal/llama_class_embedding.py +++ b/models/demos/llama3/tt/multimodal/llama_class_embedding.py @@ -8,7 +8,7 @@ import ttnn from models.common.lightweightmodule import LightweightModule -from ttnn import ReplicateTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper class TtLlamaClassEmbedding(LightweightModule): @@ -37,7 +37,7 @@ def __init__( layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) def forward(self, x): diff --git a/models/demos/llama3/tt/multimodal/llama_conv2d_patch.py b/models/demos/llama3/tt/multimodal/llama_conv2d_patch.py index f5ff04f7e3e..2da85f97e33 100644 --- a/models/demos/llama3/tt/multimodal/llama_conv2d_patch.py +++ b/models/demos/llama3/tt/multimodal/llama_conv2d_patch.py @@ -11,7 +11,7 @@ ) from models.common.lightweightmodule import LightweightModule -from ttnn import ShardTensorToMesh, ConcatMeshToTensor, ReplicateTensorToMesh +from ttnn import ShardTensorToMesh, ConcatMeshToTensor, replicate_tensor_to_mesh_mapper class TtLlamaConv2dPatch(LightweightModule): @@ -56,7 +56,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), +
mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) if bias else None @@ -76,7 +76,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) self.compute_kernel_config = ttnn.init_device_compute_kernel_config( @@ -102,7 +102,7 @@ def forward(self, x: torch.Tensor): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) out = ttnn.linear( diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_vision.py b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_vision.py index 441ccda766b..f06014218fb 100644 --- a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_vision.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_vision.py @@ -76,7 +76,7 @@ def shuffle_weight(weight): mesh_mapper=( ttnn.ShardTensorToMesh(self.mesh_device, dim=dim) if dim is not None - else ttnn.ReplicateTensorToMesh(self.mesh_device) + else ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device) ), layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, diff --git a/models/demos/llama3/tt/multimodal/llama_image_block.py b/models/demos/llama3/tt/multimodal/llama_image_block.py index 9ab361aed26..257ee5763c6 100644 --- a/models/demos/llama3/tt/multimodal/llama_image_block.py +++ b/models/demos/llama3/tt/multimodal/llama_image_block.py @@ -79,7 +79,7 @@ def __init__( state_dict[f"{state_dict_prefix}gate_attn"].unsqueeze(0).expand(1, self.hidden_size), dtype=ttnn.bfloat16, device=self.mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, ) @@ -87,7 +87,7 @@ def __init__( state_dict[f"{state_dict_prefix}gate_ffn"].unsqueeze(0).expand(1, self.hidden_size), dtype=ttnn.bfloat16, device=self.mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, ) diff --git a/models/demos/llama3/tt/multimodal/llama_image_mlp.py b/models/demos/llama3/tt/multimodal/llama_image_mlp.py index 0d56f310eaf..2c085b834d4 100644 --- a/models/demos/llama3/tt/multimodal/llama_image_mlp.py +++ b/models/demos/llama3/tt/multimodal/llama_image_mlp.py @@ -43,7 +43,7 @@ def __init__( mesh_mapper=( ttnn.ShardTensorToMesh(self.mesh_device, dim=dim) if dim is not None - else ttnn.ReplicateTensorToMesh(self.mesh_device) + else ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device) ), layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, diff --git a/models/demos/llama3/tt/multimodal/llama_layernorm.py b/models/demos/llama3/tt/multimodal/llama_layernorm.py index a20c4764ad1..737b16290af 100644 --- a/models/demos/llama3/tt/multimodal/llama_layernorm.py +++ b/models/demos/llama3/tt/multimodal/llama_layernorm.py @@ -42,7 +42,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, memory_config=weight_memory_config, cache_file_name=cache_name / "weight", - mesh_mapper=ttnn.ReplicateTensorToMesh(device) if is_mesh_device else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(device) if is_mesh_device else None, ) 
self.bias = ttnn.as_tensor( @@ -52,7 +52,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, memory_config=weight_memory_config, cache_file_name=cache_name / "bias", - mesh_mapper=ttnn.ReplicateTensorToMesh(device) if is_mesh_device else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(device) if is_mesh_device else None, ) if model_config: diff --git a/models/demos/llama3/tt/multimodal/llama_positional_embedding.py b/models/demos/llama3/tt/multimodal/llama_positional_embedding.py index af80b24b862..58aab0c4157 100644 --- a/models/demos/llama3/tt/multimodal/llama_positional_embedding.py +++ b/models/demos/llama3/tt/multimodal/llama_positional_embedding.py @@ -13,7 +13,7 @@ ) from models.common.lightweightmodule import LightweightModule -from ttnn import ShardTensorToMesh, ConcatMeshToTensor, ReplicateTensorToMesh +from ttnn import ShardTensorToMesh, ConcatMeshToTensor, replicate_tensor_to_mesh_mapper TILE_SIZE = 32 @@ -48,7 +48,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) padded_gated_embeddings, self.ar_mapping = self.generate_padded_gated_embeddings( gated_positional_embedding, gated_positional_embedding_gate ) @@ -59,7 +59,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) # Add batch and ntok dimensions @@ -72,7 +72,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) def generate_padded_gated_embeddings(self, gated_embedding, gate): diff --git a/models/demos/llama3/tt/multimodal/llama_tile_position_embedding.py b/models/demos/llama3/tt/multimodal/llama_tile_position_embedding.py index 9ef2aadddac..8a1a9d44064 100644 --- a/models/demos/llama3/tt/multimodal/llama_tile_position_embedding.py +++ b/models/demos/llama3/tt/multimodal/llama_tile_position_embedding.py @@ -13,7 +13,7 @@ ) from models.common.lightweightmodule import LightweightModule -from ttnn import ShardTensorToMesh, ConcatMeshToTensor, ReplicateTensorToMesh +from ttnn import ShardTensorToMesh, ConcatMeshToTensor, replicate_tensor_to_mesh_mapper class TtLlamaTilePositionEmbedding(LightweightModule): @@ -56,7 +56,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) if self.gated: @@ -67,7 +67,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) def generate_padded_embeddings(self, embedding: torch.Tensor, num_tiles, width): diff --git a/models/demos/llama3/tt/multimodal/llama_vision_encoder.py b/models/demos/llama3/tt/multimodal/llama_vision_encoder.py index dfe441ee039..fa8d30f6e18 100644 --- a/models/demos/llama3/tt/multimodal/llama_vision_encoder.py +++ b/models/demos/llama3/tt/multimodal/llama_vision_encoder.py @@ -37,7 +37,7 @@ def pad_seq_one_tile(x, mesh_device): device=mesh_device,
layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) x = ttnn.to_layout(x, ttnn.ROW_MAJOR_LAYOUT) @@ -239,7 +239,7 @@ def forward(self, images, ar): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) x = ttnn.to_layout(x, ttnn.ROW_MAJOR_LAYOUT) x = ttnn.reshape(x, (1, bsz * num_concurrent_media, -1, dim)) diff --git a/models/demos/llama3/tt/multimodal/llama_vision_model.py b/models/demos/llama3/tt/multimodal/llama_vision_model.py index 7fc9d630102..7b0096d1b3c 100644 --- a/models/demos/llama3/tt/multimodal/llama_vision_model.py +++ b/models/demos/llama3/tt/multimodal/llama_vision_model.py @@ -243,7 +243,7 @@ def compute_vision_tokens_masks( memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=ttnn.bfloat16, device=self.mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) padded_masks = _pad_masks( # torch.Size([1, 512, 1, 4]) @@ -314,7 +314,7 @@ def prepare_inputs_prefill( dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) tt_xattn_mask = ttnn.to_layout(tt_xattn_mask, ttnn.TILE_LAYOUT) tt_xattn_mask = ttnn.typecast(tt_xattn_mask, ttnn.bfloat4_b) @@ -333,7 +333,7 @@ def prepare_inputs_prefill( dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) tt_full_text_mask_expand_1NSH = ttnn.to_layout(tt_full_text_mask_expand_1NSH, ttnn.TILE_LAYOUT) tt_full_text_mask_expand_1NSH = ttnn.typecast(tt_full_text_mask_expand_1NSH, ttnn.bfloat4_b) @@ -356,7 +356,7 @@ def prepare_inputs_prefill( memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) else: assert cross_attention_masks is None and full_text_row_masked_out_mask is None @@ -385,7 +385,7 @@ def prepare_inputs_prefill( memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) return ( @@ -503,7 +503,7 @@ def prepare_decode_inputs_host( device=None, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) rot_position_id = torch.maximum( @@ -535,7 +535,7 @@ def prepare_decode_inputs_host( device=None, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) full_text_mask = torch.cat(full_text_mask, dim=1).unsqueeze(0) @@ -553,7 +553,7 @@ def prepare_decode_inputs_host( device=None, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) 
full_text_mask_expand_11SD = full_text_mask @@ -569,7 +569,7 @@ def prepare_decode_inputs_host( device=None, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) if isinstance(page_table, torch.Tensor): @@ -578,7 +578,7 @@ def prepare_decode_inputs_host( page_table, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) if isinstance(cross_page_table, torch.Tensor): @@ -587,7 +587,7 @@ def prepare_decode_inputs_host( cross_page_table, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) return ( diff --git a/models/demos/qwen/demo/demo.py b/models/demos/qwen/demo/demo.py index 3474877333d..a076904b1ed 100644 --- a/models/demos/qwen/demo/demo.py +++ b/models/demos/qwen/demo/demo.py @@ -283,7 +283,7 @@ def run_qwen_demo( dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), memory_config=ttnn.DRAM_MEMORY_CONFIG, ) profiler.end(f"prepare_rot_mat_for_prefill", iteration=batch_idx) @@ -371,7 +371,7 @@ def run_qwen_demo( torch.tensor([start_pos]), device=mesh_device, dtype=ttnn.int32, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_out = tt_model(pt_decode_input, current_pos_tensor, rot_mat=current_rot_mat) @@ -389,7 +389,7 @@ def run_qwen_demo( tt_out_tok = ttnn.from_torch( torch.nn.functional.pad(pt_out_batched.unsqueeze(0).unsqueeze(0).unsqueeze(0), (0, 31), "constant", 0), device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), dtype=ttnn.uint32, ) profiler.end(f"prepare_first_decode_token_{batch_idx}") @@ -419,7 +419,7 @@ def run_qwen_demo( current_pos = ttnn.from_torch( torch.tensor(decoding_pos, dtype=torch.int32), device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), dtype=ttnn.int32, ) @@ -467,12 +467,12 @@ def run_qwen_demo( current_pos_reset = ttnn.from_torch( torch.tensor(decoding_pos, dtype=torch.int32), dtype=ttnn.int32, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device) if tt_model.args.num_devices > 1 else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device) if tt_model.args.num_devices > 1 else None, ) tt_out_tok_reset = ttnn.from_torch( torch.nn.functional.pad(pt_out_batched.unsqueeze(0).unsqueeze(0).unsqueeze(0), (0, 31), "constant", 0), dtype=ttnn.uint32, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device) if tt_model.args.num_devices > 1 else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device) if tt_model.args.num_devices > 1 else None, ) ttnn.copy_host_to_device_tensor(current_pos_reset, current_pos) diff --git a/models/demos/qwen/tests/test_lm_head.py b/models/demos/qwen/tests/test_lm_head.py index b62acd9284b..a703fae02ba 100644 --- a/models/demos/qwen/tests/test_lm_head.py +++ b/models/demos/qwen/tests/test_lm_head.py @@ -65,7 +65,7 @@ def test_qwen_lm_head_inference(mesh_device, seq_len, use_program_cache, reset_s tt_input = ttnn.from_torch( torch_input, device=mesh_device, - 
mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), dtype=ttnn.bfloat8_b, memory_config=model_args.model_config["LM_HEAD_INPUT_MEMCFG"], layout=ttnn.TILE_LAYOUT, diff --git a/models/demos/qwen/tests/test_qwen_attention.py b/models/demos/qwen/tests/test_qwen_attention.py index 18ec68dba7f..c47242ebce7 100644 --- a/models/demos/qwen/tests/test_qwen_attention.py +++ b/models/demos/qwen/tests/test_qwen_attention.py @@ -88,7 +88,7 @@ def test_qwen_attention_inference(mesh_device, use_program_cache, reset_seeds, e torch.tensor([current_pos] * batch), device=mesh_device, dtype=ttnn.int32, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) attention_input = model_args.prepare_inputs_ttnn_decode( diff --git a/models/demos/qwen/tests/test_qwen_decoder.py b/models/demos/qwen/tests/test_qwen_decoder.py index ff86c59320c..7095f670d9d 100644 --- a/models/demos/qwen/tests/test_qwen_decoder.py +++ b/models/demos/qwen/tests/test_qwen_decoder.py @@ -89,7 +89,7 @@ def test_qwen_decoder_inference(mesh_device, use_program_cache, reset_seeds, ens torch.tensor([current_pos] * batch), device=mesh_device, dtype=ttnn.int32, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) decode_input = model_args.prepare_inputs_ttnn_decode( diff --git a/models/demos/qwen/tests/test_qwen_embedding.py b/models/demos/qwen/tests/test_qwen_embedding.py index 1768ba78e37..e41a436a327 100644 --- a/models/demos/qwen/tests/test_qwen_embedding.py +++ b/models/demos/qwen/tests/test_qwen_embedding.py @@ -61,7 +61,7 @@ def test_qwen_embedding(mesh_device, use_program_cache, reset_seeds, ensure_gc): tt_input = ttnn.from_torch( pt_input.squeeze(1), device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT, ) diff --git a/models/demos/qwen/tests/test_qwen_mlp.py b/models/demos/qwen/tests/test_qwen_mlp.py index 911e79aa407..1aabd937d01 100644 --- a/models/demos/qwen/tests/test_qwen_mlp.py +++ b/models/demos/qwen/tests/test_qwen_mlp.py @@ -75,7 +75,7 @@ def test_qwen_mlp_inference(mesh_device, seq_len, use_program_cache, reset_seeds tt_input = ttnn.from_torch( torch_input, device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), dtype=ttnn.bfloat8_b, memory_config=model_args.model_config["SHARDED_MLP_INPUT_MEMCFG"] if mode == "decode" diff --git a/models/demos/qwen/tests/test_qwen_model.py b/models/demos/qwen/tests/test_qwen_model.py index c07b626f571..fa492e842d3 100644 --- a/models/demos/qwen/tests/test_qwen_model.py +++ b/models/demos/qwen/tests/test_qwen_model.py @@ -158,7 +158,7 @@ def test_qwen_model_inference(mesh_device, weights, layers, use_program_cache, r torch.tensor([current_pos] * batch), device=mesh_device, dtype=ttnn.int32, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) # Run TT model diff --git a/models/demos/qwen/tests/test_qwen_perf.py b/models/demos/qwen/tests/test_qwen_perf.py index b1bfd92c77e..1c6a09446e7 100644 --- a/models/demos/qwen/tests/test_qwen_perf.py +++ b/models/demos/qwen/tests/test_qwen_perf.py @@ -151,7 +151,7 @@ def run_inference(tt_model, tt_embd, embd, encoded_prompts, generation_start_pos encoded_prompts_tensor[:, 
0].unsqueeze(0).unsqueeze(0).unsqueeze(0), (0, 31), "constant", 0 ), device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), dtype=ttnn.uint32, ) @@ -167,7 +167,7 @@ def run_inference(tt_model, tt_embd, embd, encoded_prompts, generation_start_pos current_pos = ttnn.from_torch( torch.tensor([generation_start_pos] * batch), device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), dtype=ttnn.int32, ) diff --git a/models/demos/qwen/tt/model_config.py b/models/demos/qwen/tt/model_config.py index 8b58ce59475..09f5fbf98fb 100644 --- a/models/demos/qwen/tt/model_config.py +++ b/models/demos/qwen/tt/model_config.py @@ -503,7 +503,7 @@ def prepare_inputs_ttnn_decode(self, x, input_mem_cfg, force_replicated=False): x: (batch, seq, dim) """ mesh_mapper = ( - ttnn.ReplicateTensorToMesh(self.mesh_device) + ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device) if force_replicated else ttnn.ShardTensorToMesh(self.mesh_device, dim=-1) ) @@ -559,7 +559,7 @@ def prepare_inputs_ttnn_prefill(self, x_bsh, force_replicated=False): x_1BSH = x_bsh.unsqueeze(0) mesh_mapper = ( - ttnn.ReplicateTensorToMesh(self.mesh_device) + ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device) if force_replicated else ttnn.ShardTensorToMesh(self.mesh_device, dim=-1) ) diff --git a/models/demos/qwen/tt/qwen_common.py b/models/demos/qwen/tt/qwen_common.py index b6649cce918..d48307f18e5 100644 --- a/models/demos/qwen/tt/qwen_common.py +++ b/models/demos/qwen/tt/qwen_common.py @@ -115,14 +115,14 @@ def get_prefill_rot_mat(head_dim, max_seq_len, mesh_device, seq_len): dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) sin_gathereds = ttnn.from_torch( sin_gathered, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) rot_mats = [cos_gathereds, sin_gathereds] @@ -169,13 +169,13 @@ def get_single_rot_mat( device=mesh_device if not on_host else None, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device) if num_devices > 1 or not on_host else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device) if num_devices > 1 or not on_host else None, ), ttnn.from_torch( rot_matrix.unsqueeze(0).unsqueeze(0), # 1,1,head_dim,head_dim device=mesh_device if not on_host else None, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device) if num_devices > 1 or not on_host else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device) if num_devices > 1 or not on_host else None, ) diff --git a/models/demos/t3000/falcon40b/tests/test_falcon_attention.py b/models/demos/t3000/falcon40b/tests/test_falcon_attention.py index a3bb1c8c386..c7d142c6a9f 100644 --- a/models/demos/t3000/falcon40b/tests/test_falcon_attention.py +++ b/models/demos/t3000/falcon40b/tests/test_falcon_attention.py @@ -7,7 +7,7 @@ from loguru import logger import ttnn -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.falcon40b.reference.hf_modeling_falcon import ( FalconForCausalLM, ) @@ -90,7 +90,7 @@ def 
run_test_FalconAttention_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["ATTN_INPUT_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) attention_mask_memconfig = model_config["ATTN_MASK_MEMCFG"] @@ -105,7 +105,7 @@ def run_test_FalconAttention_inference( layout=ttnn.ROW_MAJOR_LAYOUT, device=mesh_device, memory_config=attention_mask_memconfig, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), preprocess=lambda x: (x * (-1e5)).expand(1, 1, -1, -1), ) @@ -161,7 +161,7 @@ def run_test_FalconAttention_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["LN_ATTN_OUTPUT_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), preprocess=lambda x: x.unsqueeze(1).transpose(0, 2), ) diff --git a/models/demos/t3000/falcon40b/tests/test_falcon_mlp.py b/models/demos/t3000/falcon40b/tests/test_falcon_mlp.py index 1dd2eacd664..877a3143170 100644 --- a/models/demos/t3000/falcon40b/tests/test_falcon_mlp.py +++ b/models/demos/t3000/falcon40b/tests/test_falcon_mlp.py @@ -85,7 +85,7 @@ def run_test_FalconMLP_inference( device=mesh_device, memory_config=model_config["LN_MLP_OUTPUT_MEMCFG"], layout=ttnn.TILE_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_out = tt_FalconMLP_model(tt_mlp_input, llm_mode) diff --git a/models/demos/t3000/falcon40b/tt/falcon_attention.py b/models/demos/t3000/falcon40b/tt/falcon_attention.py index b3adfd184c9..e4f0385b611 100644 --- a/models/demos/t3000/falcon40b/tt/falcon_attention.py +++ b/models/demos/t3000/falcon40b/tt/falcon_attention.py @@ -8,7 +8,7 @@ from typing import Optional, Tuple import ttnn -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh +from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper from models.utility_functions import nearest_32 from models.demos.t3000.falcon40b.tt.model_utils import convert_to_layout @@ -46,7 +46,7 @@ def generate_cos_sin_cache( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["COS_CACHED_WEIGHTS_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), cache_file_name=cos_cached_path, ) @@ -58,7 +58,7 @@ def generate_cos_sin_cache( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["SIN_CACHED_WEIGHTS_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), cache_file_name=sin_cached_path, ) @@ -219,7 +219,7 @@ def initialize_kvcache(self): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=self.model_config["DRAM_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), cache_file_name=kv_cache_path, ) @@ -229,7 +229,7 @@ def initialize_kvcache(self): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=self.model_config["DRAM_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), cache_file_name=kv_cache_path, ) diff --git a/models/demos/t3000/falcon40b/tt/falcon_decoder.py b/models/demos/t3000/falcon40b/tt/falcon_decoder.py index d78b69c4aed..a17262a9043 100644 --- a/models/demos/t3000/falcon40b/tt/falcon_decoder.py +++ b/models/demos/t3000/falcon40b/tt/falcon_decoder.py @@ -6,7 +6,7 @@ 
from typing import Optional, Tuple import ttnn -from ttnn import ReplicateTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper from models.demos.t3000.falcon40b.tt.falcon_attention import TtFalconAttention from models.demos.t3000.falcon40b.tt.falcon_mlp import TtFalconMLP @@ -80,7 +80,7 @@ def pad_ln_params(x): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=self.model_config["LN_MLP_WEIGHTS_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), cache_file_name=ln_mlp_weights_path, preprocess=pad_ln_params, ) @@ -93,7 +93,7 @@ def pad_ln_params(x): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=self.model_config["LN_MLP_BIAS_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), cache_file_name=ln_mlp_bias_path, preprocess=pad_ln_params, ) @@ -111,7 +111,7 @@ def pad_ln_params(x): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=self.model_config["LN_ATTN_WEIGHTS_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), cache_file_name=ln_attn_weights_path, preprocess=pad_ln_params, ) @@ -124,7 +124,7 @@ def pad_ln_params(x): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=self.model_config["LN_ATTN_BIAS_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), cache_file_name=ln_attn_bias_path, preprocess=pad_ln_params, ) diff --git a/models/demos/t3000/falcon40b/tt/falcon_mlp.py b/models/demos/t3000/falcon40b/tt/falcon_mlp.py index ba75a8b4a95..5ed3b36bf9a 100644 --- a/models/demos/t3000/falcon40b/tt/falcon_mlp.py +++ b/models/demos/t3000/falcon40b/tt/falcon_mlp.py @@ -8,7 +8,7 @@ from typing import List from models.demos.t3000.falcon40b.tt.model_utils import falcon_prefill_matmul, determine_tensor_deallocation -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh +from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper class TtFalconMLP: @@ -84,7 +84,7 @@ def _allocate_output_mlp_tensors(self): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=self.model_config["DEFAULT_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) def __call__(self, x: List[ttnn.Tensor], llm_mode: str) -> List[ttnn.Tensor]: diff --git a/models/demos/t3000/falcon40b/tt/falcon_model.py b/models/demos/t3000/falcon40b/tt/falcon_model.py index 1c2f7b12574..c6a92cf1e4e 100644 --- a/models/demos/t3000/falcon40b/tt/falcon_model.py +++ b/models/demos/t3000/falcon40b/tt/falcon_model.py @@ -9,7 +9,7 @@ import ttnn -from ttnn import ReplicateTensorToMesh, ShardTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper, ShardTensorToMesh from models.demos.t3000.falcon40b.tt.falcon_decoder import TtFalconDecoderLayer from models.demos.t3000.falcon40b.tt.falcon_embeddings import TtFalconEmbeddings from models.demos.t3000.falcon40b.tt.falcon_attention import generate_cos_sin_cache @@ -107,7 +107,7 @@ def __init__( layout=ttnn.ROW_MAJOR_LAYOUT, device=mesh_device, memory_config=self.model_config["LN_F_WEIGHTS_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), cache_file_name=layernorm_weights_path, preprocess=lambda x: x.reshape(1, 1, -1, 32), ) @@ -118,7 +118,7 @@ def __init__( layout=ttnn.ROW_MAJOR_LAYOUT, device=mesh_device, 
memory_config=self.model_config["LN_F_BIAS_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), cache_file_name=layernorm_bias_path, preprocess=lambda x: x.reshape(1, 1, -1, 32), ) @@ -138,7 +138,7 @@ def create_attn_mask(self, max_seq_len): layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=attention_mask_memconfig, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), preprocess=lambda x: (x * -1e5), ) @@ -181,7 +181,7 @@ def model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) # Generate input and attention_mask --------------------------------------------- @@ -230,7 +230,7 @@ def model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=self.model_config["DEFAULT_MEMCFG"], - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), preprocess=lambda x: (x.transpose(0, 2) * -1e5).expand(1, 1, -1, -1), ) diff --git a/models/demos/t3000/falcon40b/tt/model_utils.py b/models/demos/t3000/falcon40b/tt/model_utils.py index 25ba146554f..e3635da4699 100644 --- a/models/demos/t3000/falcon40b/tt/model_utils.py +++ b/models/demos/t3000/falcon40b/tt/model_utils.py @@ -433,7 +433,7 @@ def generate_layernorm_persistent_tensors(seq_len, slice_size, ln_output_tensors layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) if name in ln_output_tensors_dict and ln_output_tensors_dict[name] is not None: ln_output_tensors_dict[name].update({seq_len: output_tensor}) diff --git a/models/demos/t3000/llama2_70b/demo/demo_continuous_batching_paged_attention.py b/models/demos/t3000/llama2_70b/demo/demo_continuous_batching_paged_attention.py index 02a6684d838..c08837dfd65 100644 --- a/models/demos/t3000/llama2_70b/demo/demo_continuous_batching_paged_attention.py +++ b/models/demos/t3000/llama2_70b/demo/demo_continuous_batching_paged_attention.py @@ -13,7 +13,7 @@ import pytest from loguru import logger import ttnn -from ttnn import ReplicateTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper from models.demos.t3000.llama2_70b.reference.llama.llama import Llama from transformers.generation.utils import top_k_top_p_filtering @@ -243,7 +243,7 @@ def run_decode( static_page_table, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(model.mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(model.mesh_device), ) page_table_tt = ttnn.to_device(page_table_tt, model.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG) diff --git a/models/demos/t3000/llama2_70b/tests/test_chunked_generation.py b/models/demos/t3000/llama2_70b/tests/test_chunked_generation.py index 22ba67ece5d..48dece25332 100644 --- a/models/demos/t3000/llama2_70b/tests/test_chunked_generation.py +++ b/models/demos/t3000/llama2_70b/tests/test_chunked_generation.py @@ -5,7 +5,7 @@ from loguru import logger import torch import ttnn -from ttnn import ReplicateTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper from 
models.demos.t3000.llama2_70b.reference.llama.llama import Llama from models.demos.t3000.llama2_70b.tt.llama_generation import ( diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_attention.py b/models/demos/t3000/llama2_70b/tests/test_llama_attention.py index 72bd9b7091f..e351736c14b 100644 --- a/models/demos/t3000/llama2_70b/tests/test_llama_attention.py +++ b/models/demos/t3000/llama2_70b/tests/test_llama_attention.py @@ -6,7 +6,7 @@ from loguru import logger import torch import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.llama2_70b.reference.llama.llama import Llama from models.demos.t3000.llama2_70b.tt.llama_attention_optimized import TtLlamaAttention_optimized @@ -130,7 +130,7 @@ def tt_llama_attention_prepare_inputs( dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(llama_attention_model.mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(llama_attention_model.mesh_device), device=llama_attention_model.mesh_device, ) xs = ttnn.to_device(xs, llama_attention_model.mesh_device) @@ -149,7 +149,7 @@ def tt_llama_attention_prepare_inputs( cache_file_name=cache_name(f"cos_gathered_prefill_{start_pos}_to_{start_pos + seq_len}"), memory_config=ttnn.DRAM_MEMORY_CONFIG, device=llama_attention_model.mesh_device, - mesh_mapper=ReplicateTensorToMesh(llama_attention_model.mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(llama_attention_model.mesh_device), ) sin_gathereds = ttnn.as_tensor( sin_gathered, @@ -158,7 +158,7 @@ def tt_llama_attention_prepare_inputs( cache_file_name=cache_name(f"sin_gathered_prefill_{start_pos}_to_{start_pos + seq_len}"), memory_config=ttnn.DRAM_MEMORY_CONFIG, device=llama_attention_model.mesh_device, - mesh_mapper=ReplicateTensorToMesh(llama_attention_model.mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(llama_attention_model.mesh_device), ) cos_gathereds = ttnn.to_device(cos_gathereds, llama_attention_model.mesh_device) @@ -181,7 +181,7 @@ def tt_llama_attention_prepare_inputs( dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(llama_attention_model.mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(llama_attention_model.mesh_device), device=llama_attention_model.mesh_device, ) xs = ttnn.to_device(xs, llama_attention_model.mesh_device) @@ -194,7 +194,7 @@ def tt_llama_attention_prepare_inputs( layout=ttnn.ROW_MAJOR_LAYOUT, device=llama_attention_model.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(llama_attention_model.mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(llama_attention_model.mesh_device), ) rot_mats = rope_setup.get_rot_mats(cache_idxs) @@ -263,7 +263,7 @@ def run_test_LlamaAttention_inference( layout=ttnn.TILE_LAYOUT, device=t3k_mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) transformation_mats = ttnn.to_device(transformation_mats, t3k_mesh_device) transformation_mats = {"prefill": transformation_mats} @@ -284,7 +284,7 @@ def run_test_LlamaAttention_inference( page_table, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) page_table_tt = ttnn.to_device(page_table_tt, t3k_mesh_device, 
memory_config=ttnn.DRAM_MEMORY_CONFIG) @@ -352,7 +352,7 @@ def run_test_LlamaAttention_inference( layout=ttnn.ROW_MAJOR_LAYOUT, device=t3k_mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) # SDPA requires that the page table batch dim matches the input batch dim, which must be 1 in prefill prefill_page_table = page_table[0:1, :] @@ -362,7 +362,7 @@ def run_test_LlamaAttention_inference( layout=ttnn.ROW_MAJOR_LAYOUT, device=t3k_mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) chunk_tt_input = tt_input[:, chunk_start:chunk_end] diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_decoder.py b/models/demos/t3000/llama2_70b/tests/test_llama_decoder.py index f57969b7c7f..5e94c4ddfc5 100644 --- a/models/demos/t3000/llama2_70b/tests/test_llama_decoder.py +++ b/models/demos/t3000/llama2_70b/tests/test_llama_decoder.py @@ -6,7 +6,7 @@ from loguru import logger import torch import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.llama2_70b.reference.llama.llama import Llama from models.demos.t3000.llama2_70b.tt.llama_decoder_optimized import TtLlamaDecoder_optimized @@ -141,7 +141,7 @@ def tt_llama_decoder_prepare_inputs(llama_decoder_model, x, start_pos, mode, rop cache_file_name=cache_name(f"cos_gathered_prefill_{seq_len}"), memory_config=ttnn.DRAM_MEMORY_CONFIG, device=llama_decoder_model.mesh_device, - mesh_mapper=ReplicateTensorToMesh(llama_decoder_model.mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(llama_decoder_model.mesh_device), ) sin_gathereds = ttnn.as_tensor( sin_gathered, @@ -150,7 +150,7 @@ def tt_llama_decoder_prepare_inputs(llama_decoder_model, x, start_pos, mode, rop cache_file_name=cache_name(f"sin_gathered_prefill_{seq_len}"), memory_config=ttnn.DRAM_MEMORY_CONFIG, device=llama_decoder_model.mesh_device, - mesh_mapper=ReplicateTensorToMesh(llama_decoder_model.mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(llama_decoder_model.mesh_device), ) cos_gathereds = ttnn.to_device(cos_gathereds, llama_decoder_model.mesh_device) sin_gathereds = ttnn.to_device(sin_gathereds, llama_decoder_model.mesh_device) @@ -184,7 +184,7 @@ def tt_llama_decoder_prepare_inputs(llama_decoder_model, x, start_pos, mode, rop layout=ttnn.ROW_MAJOR_LAYOUT, device=llama_decoder_model.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(llama_decoder_model.mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(llama_decoder_model.mesh_device), ) rot_mats = rope_setup.get_rot_mats(cache_idxs) @@ -248,7 +248,7 @@ def run_test_LlamaDecoder_inference( layout=ttnn.TILE_LAYOUT, device=t3k_mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) transformation_mats = ttnn.to_device(transformation_mats, t3k_mesh_device) transformation_mats = {"prefill": transformation_mats} diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_generation.py b/models/demos/t3000/llama2_70b/tests/test_llama_generation.py index babfe3b3657..2c1b66ceaa2 100644 --- a/models/demos/t3000/llama2_70b/tests/test_llama_generation.py +++ b/models/demos/t3000/llama2_70b/tests/test_llama_generation.py @@ -6,7 +6,7 @@ import torch from torch import nn 
import ttnn -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor import scipy diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_model.py b/models/demos/t3000/llama2_70b/tests/test_llama_model.py index ef41fbe6d89..100dce9c12e 100644 --- a/models/demos/t3000/llama2_70b/tests/test_llama_model.py +++ b/models/demos/t3000/llama2_70b/tests/test_llama_model.py @@ -6,7 +6,7 @@ from loguru import logger import torch import ttnn -from ttnn import ConcatMeshToTensor, ReplicateTensorToMesh +from ttnn import ConcatMeshToTensor, replicate_tensor_to_mesh_mapper import os import scipy diff --git a/models/demos/t3000/llama2_70b/tt/llama_decoder_optimized.py b/models/demos/t3000/llama2_70b/tt/llama_decoder_optimized.py index a1bd6b1565e..46b8795ce9c 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_decoder_optimized.py +++ b/models/demos/t3000/llama2_70b/tt/llama_decoder_optimized.py @@ -6,7 +6,7 @@ from typing import List import torch import ttnn -from ttnn import ReplicateTensorToMesh, ShardTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper, ShardTensorToMesh from models.demos.t3000.llama2_70b.tt.llama_attention_optimized import TtLlamaAttention_optimized from models.demos.t3000.llama2_70b.tt.llama_mlp_optimized import TtLlamaMLP_optimized @@ -106,7 +106,7 @@ def load_weights(self): layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), cache_file_name=self.cache_path / attn_norm_str, ) self.attn_norm = ttnn.to_device(attn_norm_ttnn, self.mesh_device) @@ -128,7 +128,7 @@ def load_weights(self): layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), cache_file_name=self.cache_path / ffn_norm_str, ) self.ffn_norm = ttnn.to_device(ffn_norm_ttnn, self.mesh_device) diff --git a/models/demos/t3000/llama2_70b/tt/llama_generation.py b/models/demos/t3000/llama2_70b/tt/llama_generation.py index 0aee8f7bf77..a1087e9d645 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_generation.py +++ b/models/demos/t3000/llama2_70b/tt/llama_generation.py @@ -5,7 +5,7 @@ import math import torch import ttnn -from ttnn import ConcatMeshToTensor, ReplicateTensorToMesh +from ttnn import ConcatMeshToTensor, replicate_tensor_to_mesh_mapper from loguru import logger diff --git a/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py b/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py index 32bce8227ec..ceac83f3417 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py +++ b/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py @@ -7,7 +7,7 @@ from tqdm import tqdm import torch import ttnn -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh +from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper from models.utility_functions import nearest_32, profiler @@ -66,7 +66,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) transformation_mats_prefill = ttnn.to_device(transformation_mats_prefill, mesh_device) @@ -150,7 +150,7 @@ def load_weights(self): layout=ttnn.ROW_MAJOR_LAYOUT, 
device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), cache_file_name=self.cache_path / norm_str, ) self.norm = ttnn.to_device(norm_ttnn, self.mesh_device) @@ -210,7 +210,7 @@ def prepare_inputs( inp_ids, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) if mode == "prefill": @@ -235,7 +235,7 @@ def prepare_inputs( cache_file_name=cache_name(f"cos_gathered_prefill_{start_pos}_{start_pos+seq_len}"), memory_config=ttnn.DRAM_MEMORY_CONFIG, device=self.mesh_device, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) sin_gathereds = ttnn.as_tensor( sin_gathered, @@ -244,7 +244,7 @@ def prepare_inputs( cache_file_name=cache_name(f"sin_gathered_prefill_{start_pos}_{start_pos+seq_len}"), memory_config=ttnn.DRAM_MEMORY_CONFIG, device=self.mesh_device, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) cos_gathereds = ttnn.to_device(cos_gathereds, self.mesh_device) sin_gathereds = ttnn.to_device(sin_gathereds, self.mesh_device) @@ -261,7 +261,7 @@ def prepare_inputs( memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) if chunk_page_table is not None: chunk_page_table = ttnn.as_tensor( @@ -270,7 +270,7 @@ def prepare_inputs( memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) return (xs, start_pos, rot_mats, rot_idxs_tt, cache_idxs_tt, page_table, chunk_page_table) @@ -288,7 +288,7 @@ def prepare_inputs( cache_idxs, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) rot_mats = None # Created in prepare_device_inputs @@ -303,7 +303,7 @@ def prepare_inputs( page_table, dtype=ttnn.int32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) return (xs, start_pos, rot_mats, rot_idxs_tt, cache_idxs_tt, page_table) diff --git a/models/demos/t3000/llama2_70b/tt/llama_rope.py b/models/demos/t3000/llama2_70b/tt/llama_rope.py index e7f4baeb4fd..02d16ea0271 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_rope.py +++ b/models/demos/t3000/llama2_70b/tt/llama_rope.py @@ -4,7 +4,7 @@ import torch import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.common.lightweightmodule import LightweightModule from models.demos.t3000.llama2_70b.tt.llama_common import precompute_freqs, get_rot_transformation_mat, gather_cos_sin from loguru import logger @@ -48,14 +48,14 @@ def __init__( device=device, layout=ttnn.ROW_MAJOR_LAYOUT, dtype=datatype, - mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None, + ttnn.replicate_tensor_to_mesh_mapper(device) if self.is_mesh_device else None, ) self.sin_matrix = ttnn.from_torch( sin_matrix, device=device, layout=ttnn.ROW_MAJOR_LAYOUT, dtype=datatype, - mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device 
else None, + ttnn.replicate_tensor_to_mesh_mapper(device) if self.is_mesh_device else None, ) # Generate the transformation matrix @@ -74,7 +74,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, dtype=datatype, memory_config=trans_mat_mem_config, - mesh_mapper=ReplicateTensorToMesh(device) if self.is_mesh_device else None, + ttnn.replicate_tensor_to_mesh_mapper(device) if self.is_mesh_device else None, ) def get_trans_mats(self): @@ -93,7 +93,7 @@ def get_rot_idxs(self, position_idxs): position_idxs, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(self.device) if self.is_mesh_device else None, + ttnn.replicate_tensor_to_mesh_mapper(self.device) if self.is_mesh_device else None, ) return rot_idxs diff --git a/models/demos/t3000/mixtral8x7b/demo/demo.py b/models/demos/t3000/mixtral8x7b/demo/demo.py index be02adcf491..afcf5b15d92 100644 --- a/models/demos/t3000/mixtral8x7b/demo/demo.py +++ b/models/demos/t3000/mixtral8x7b/demo/demo.py @@ -9,7 +9,7 @@ from time import time import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.mixtral8x7b.tt.mixtral_common import ( load_inputs, preprocess_inputs, diff --git a/models/demos/t3000/mixtral8x7b/demo/demo_with_prefill.py b/models/demos/t3000/mixtral8x7b/demo/demo_with_prefill.py index 408b223e3cf..2de3717c87b 100644 --- a/models/demos/t3000/mixtral8x7b/demo/demo_with_prefill.py +++ b/models/demos/t3000/mixtral8x7b/demo/demo_with_prefill.py @@ -9,7 +9,7 @@ from time import time import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.mixtral8x7b.tt.mixtral_common import ( load_inputs, preprocess_inputs_prefill, @@ -178,7 +178,7 @@ def run_mixtral_demo(user_input, batch_size, mesh_device, instruct_mode, test_pr layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) profiler.end("prepare_rot_mat_for_prefill") diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_attention.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_attention.py index 957be57c7de..a2c04f6c58f 100644 --- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_attention.py +++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_attention.py @@ -7,7 +7,7 @@ from loguru import logger import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.mixtral8x7b.tt.mixtral_attention import TtMixtralAttention from models.demos.t3000.mixtral8x7b.tt.mixtral_common import prepare_inputs_ttnn, get_single_rot_mat from models.demos.t3000.mixtral8x7b.reference.model import Attention, precompute_freqs_cis diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_attention_prefill.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_attention_prefill.py index d4e50a5f5cb..2f78e8254f5 100644 --- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_attention_prefill.py +++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_attention_prefill.py @@ -7,7 +7,7 @@ from loguru import logger import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.mixtral8x7b.tt.mixtral_attention import TtMixtralAttention 
from models.demos.t3000.mixtral8x7b.tt.mixtral_common import ( prepare_inputs_ttnn_prefill, @@ -59,7 +59,7 @@ def test_mixtral_attention_inference(t3k_mesh_device, use_program_cache, reset_s layout=ttnn.TILE_LAYOUT, device=t3k_mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) # Load ttnn model diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_decoder.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_decoder.py index 36b035b536e..416e81d59a1 100644 --- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_decoder.py +++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_decoder.py @@ -12,7 +12,7 @@ from models.demos.t3000.mixtral8x7b.reference.model import TransformerBlock, precompute_freqs_cis from models.demos.t3000.mixtral8x7b.tt.model_config import TtModelArgs from models.utility_functions import comp_pcc, comp_allclose -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor @pytest.mark.parametrize( diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_decoder_prefill.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_decoder_prefill.py index dc4b84ba4ef..8ff473a67d5 100644 --- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_decoder_prefill.py +++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_decoder_prefill.py @@ -17,7 +17,7 @@ from models.demos.t3000.mixtral8x7b.reference.model import TransformerBlock, precompute_freqs_cis, RMSNorm from models.demos.t3000.mixtral8x7b.tt.model_config import TtModelArgs from models.utility_functions import comp_pcc, comp_allclose -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor @pytest.mark.parametrize( @@ -57,7 +57,7 @@ def test_mixtral_decoder_inference(t3k_mesh_device, use_program_cache, reset_see layout=ttnn.TILE_LAYOUT, device=t3k_mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) # Initialize TT model diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_mlp.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_mlp.py index 932a60af16f..79f96b157ca 100644 --- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_mlp.py +++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_mlp.py @@ -6,7 +6,7 @@ from loguru import logger import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.mixtral8x7b.tt.mixtral_mlp import TtMixtralMLP from models.demos.t3000.mixtral8x7b.reference.model import FeedForward, RMSNorm @@ -62,7 +62,7 @@ def test_mixtral_mlp_inference(t3k_mesh_device, use_program_cache, reset_seeds): dtype=ttnn.bfloat16, memory_config=ttnn.L1_MEMORY_CONFIG, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) tt_output = tt_model(tt_input) diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_mlp_prefill.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_mlp_prefill.py index 7e952a57d98..55412b4f9c2 100644 --- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_mlp_prefill.py +++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_mlp_prefill.py @@ -7,7 +7,7 @@ import pytest import ttnn -from ttnn import 
ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.mixtral8x7b.tt.mixtral_mlp import TtMixtralMLP from models.demos.t3000.mixtral8x7b.reference.model import FeedForward, RMSNorm @@ -73,7 +73,7 @@ def test_mixtral_mlp_inference(t3k_mesh_device, use_program_cache, reset_seeds, dtype=ttnn.bfloat8_b, memory_config=ttnn.DRAM_MEMORY_CONFIG, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) tt_input = ttnn.to_device(tt_input, t3k_mesh_device) diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_model.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_model.py index afb36a0a7f6..f0fe2b92fd4 100644 --- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_model.py +++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_model.py @@ -9,7 +9,7 @@ from sklearn.metrics import top_k_accuracy_score import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.mixtral8x7b.tt.mixtral_common import prepare_inputs_ttnn from models.demos.t3000.mixtral8x7b.tt.mixtral_model import TtTransformer diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_model_prefill.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_model_prefill.py index 876392eecc8..672a2995152 100644 --- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_model_prefill.py +++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_model_prefill.py @@ -7,7 +7,7 @@ from loguru import logger import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.mixtral8x7b.tt.mixtral_common import ( prepare_inputs_ttnn_prefill, @@ -89,7 +89,7 @@ def test_mixtral_model_inference_CI(t3k_mesh_device, use_program_cache, reset_se layout=ttnn.TILE_LAYOUT, device=t3k_mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) # Load TTNN model tt_model = TtTransformer( diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_moe.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_moe.py index 10a1e2e0bc9..60a683f534c 100644 --- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_moe.py +++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_moe.py @@ -7,7 +7,7 @@ from loguru import logger import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.mixtral8x7b.tt.mixtral_mlp import TtMixtralMLP from models.demos.t3000.mixtral8x7b.tt.mixtral_moe import TtMoeLayer @@ -84,7 +84,7 @@ def test_mixtral_moe_inference(t3k_mesh_device, use_program_cache, reset_seeds): dtype=ttnn.bfloat16, memory_config=ttnn.L1_MEMORY_CONFIG, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) # Run TT model diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_moe_prefill.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_moe_prefill.py index 5e8df333fd7..a7d788ac8de 100644 --- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_moe_prefill.py +++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_moe_prefill.py @@ -7,7 +7,7 @@ from loguru import logger import ttnn -from ttnn import 
ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.mixtral8x7b.tt.mixtral_mlp import TtMixtralMLP from models.demos.t3000.mixtral8x7b.tt.mixtral_moe import TtMoeLayer @@ -91,7 +91,7 @@ def test_mixtral_moe_inference(t3k_mesh_device, use_program_cache, reset_seeds, dtype=ttnn.bfloat16, memory_config=ttnn.DRAM_MEMORY_CONFIG, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) # Run TT model diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py index 1fa29fac602..81c8b6558c7 100644 --- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py +++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py @@ -6,7 +6,7 @@ import pytest import ttnn -from ttnn import ConcatMeshToTensor, ReplicateTensorToMesh +from ttnn import ConcatMeshToTensor, replicate_tensor_to_mesh_mapper from models.demos.t3000.mixtral8x7b.tt.mixtral_common import ( preprocess_inputs_prefill, @@ -327,7 +327,7 @@ def run_inference_prefill(tt_model, model_args, prefill_seqlen, mesh_device, pt_ layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) profiler.end("prefill_prepare_rot_matrices") diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perplexity.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perplexity.py index 1418f44c19d..459e1a4ffad 100644 --- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perplexity.py +++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perplexity.py @@ -12,7 +12,7 @@ import numpy as np import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.mixtral8x7b.tt.mixtral_common import ( prepare_inputs_ttnn, get_single_rot_mat, diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_rms_norm.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_rms_norm.py index df46b58b1d0..601fb7e41af 100644 --- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_rms_norm.py +++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_rms_norm.py @@ -7,7 +7,7 @@ from loguru import logger import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.common.rmsnorm import RMSNorm as TtRMSNorm from models.demos.t3000.mixtral8x7b.reference.model import RMSNorm as RefRMSNorm @@ -50,7 +50,7 @@ def test_mixtral_rms_norm_inference(t3k_mesh_device, use_program_cache, reset_se device=t3k_mesh_device, dtype=dtype, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) tt_output = tt_model(tt_input, mode="decode") diff --git a/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py b/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py index 1b27ad4a3f3..641c42dd847 100644 --- a/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py +++ b/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py @@ -5,7 +5,7 @@ import torch import ttnn from models.utility_functions import nearest_32 -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper, 
ConcatMeshToTensor from models.common.lightweightmodule import LightweightModule @@ -135,7 +135,7 @@ def __init__(self, mesh_device, state_dict, args, layer_num, dtype): self.reduce_mask = ttnn.from_torch( reduce_mask_torch, device=self.mesh_device, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, ) diff --git a/models/demos/t3000/mixtral8x7b/tt/mixtral_common.py b/models/demos/t3000/mixtral8x7b/tt/mixtral_common.py index 8a061d72b2e..dfdd8179bbf 100644 --- a/models/demos/t3000/mixtral8x7b/tt/mixtral_common.py +++ b/models/demos/t3000/mixtral8x7b/tt/mixtral_common.py @@ -5,7 +5,7 @@ from loguru import logger import torch import ttnn -from ttnn import ReplicateTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper from models.utility_functions import nearest_32 import json import math @@ -55,7 +55,7 @@ def preprocess_inputs(input_prompts, tokenizer, model_args, dtype, instruct, mes device=mesh_device, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) for i in range(max_prompt_len) ] @@ -65,7 +65,7 @@ def preprocess_inputs(input_prompts, tokenizer, model_args, dtype, instruct, mes device=mesh_device, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) for i in range(max_prompt_len) ] @@ -183,7 +183,7 @@ def prepare_inputs_ttnn(x_bsh, hidden_size, mesh_device): dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.L1_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) return xs_1SBH @@ -224,7 +224,7 @@ def cache_attention(mesh_device, state_dict, model_args, current_rot_mat, rot_ma layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=ttnn.L1_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_attn = TtMixtralAttention( @@ -295,13 +295,13 @@ def get_single_rot_mat(dhead, mesh_device, start_pos=0, theta: float = 1000000.0 device=mesh_device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ), ttnn.from_torch( rot_matrix.unsqueeze(0).unsqueeze(0), # 1,1,head_dim,head_dim device=mesh_device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) @@ -330,13 +330,13 @@ def get_single_rot_mat_multi_pos(dhead, mesh_device, start_pos_ids, theta: float device=mesh_device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ), ttnn.from_torch( rot_matrix.unsqueeze(0).unsqueeze(0).repeat(1, len(start_pos_ids), 1, 1), # 1,1,head_dim,head_dim device=mesh_device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) @@ -376,14 +376,14 @@ def get_prefill_rot_mat(head_dim, max_seq_len, mesh_device, seq_len): cos_gathered, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), device=mesh_device, ) sin_gathereds = ttnn.from_torch( sin_gathered, 
dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), device=mesh_device, ) @@ -421,7 +421,7 @@ def prepare_inputs_ttnn_prefill(x_bsh, mesh_device, num_tokens=None): dtype=attn_mask_dtype, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) # input goes to L1 @@ -431,7 +431,7 @@ def prepare_inputs_ttnn_prefill(x_bsh, mesh_device, num_tokens=None): dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) return xs_1BSH, attn_mask, attn_mask_torch diff --git a/models/demos/t3000/mixtral8x7b/tt/mixtral_model.py b/models/demos/t3000/mixtral8x7b/tt/mixtral_model.py index 093b3c8f7b4..b9288d09fb7 100644 --- a/models/demos/t3000/mixtral8x7b/tt/mixtral_model.py +++ b/models/demos/t3000/mixtral8x7b/tt/mixtral_model.py @@ -7,7 +7,7 @@ from models.common.rmsnorm import RMSNorm from models.common.lightweightmodule import LightweightModule from models.demos.t3000.mixtral8x7b.tt.mixtral_common import get_single_rot_mat_multi_pos, get_single_rot_mat_torch -from ttnn import ReplicateTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper import torch @@ -55,7 +55,7 @@ def __init__(self, mesh_device, state_dict, args, dtype, layers, start_pos_ids, dtype=dtype, memory_config=self.model_config["OUTPUT_WEIGHTS_MEMCFG"], cache_file_name=output_cache_name, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) self.compute_kernel = self.args.get_compute_kernel_config() @@ -86,7 +86,7 @@ def forward( device=self.mesh_device, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) else: rot_mats = self.current_rot_mat diff --git a/models/demos/t3000/mixtral8x7b/tt/mixtral_moe.py b/models/demos/t3000/mixtral8x7b/tt/mixtral_moe.py index c71feb93bf5..1fc4d945045 100644 --- a/models/demos/t3000/mixtral8x7b/tt/mixtral_moe.py +++ b/models/demos/t3000/mixtral8x7b/tt/mixtral_moe.py @@ -4,7 +4,7 @@ import torch import ttnn -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh +from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper from models.common.lightweightmodule import LightweightModule @@ -58,7 +58,7 @@ def __init__(self, mesh_device, state_dict, experts, args, layer_num, dtype): dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) self.top8_mask_11B_64 = ttnn.sum(self.top8_mask_11B_64, dim=2) @@ -69,7 +69,7 @@ def __init__(self, mesh_device, state_dict, experts, args, layer_num, dtype): dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) self.top2_mask_11BB = ttnn.sum(self.top2_mask_11BB, dim=2) @@ -81,7 +81,7 @@ def __init__(self, mesh_device, state_dict, experts, args, layer_num, dtype): dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=self.mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) def forward(self, inputs, mode="decode"): diff --git 
a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_attention.py b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_attention.py index 98322a8f0c6..edbb7a18da1 100644 --- a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_attention.py +++ b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_attention.py @@ -21,7 +21,7 @@ import transformers from loguru import logger -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor PRETRAINED_MODEL_NAME = f"tiiuae/falcon-7b-instruct" @@ -144,7 +144,7 @@ def test_falcon_attention( tt_cache_path=get_tt_cache_path(f"{model_name}"), device=mesh_device, base_file_name=get_model_prefix(), - weights_mesh_mapper=ReplicateTensorToMesh(mesh_device), + weights_mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ), ) tt_FalconAttention_model = TtFalconAttention( diff --git a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_causallm.py b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_causallm.py index 1de4f9a058c..c74076bd75e 100644 --- a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_causallm.py +++ b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_causallm.py @@ -22,7 +22,7 @@ ) from loguru import logger -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor PRETRAINED_MODEL_NAME = f"tiiuae/falcon-7b-instruct" @@ -153,7 +153,7 @@ def convert_to_ttnn(model, name): model_config, tt_cache_path=get_tt_cache_path(f"{model_version}"), device=mesh_device, - weights_mesh_mapper=ReplicateTensorToMesh(mesh_device), + weights_mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ), convert_to_ttnn=convert_to_ttnn, ) @@ -369,7 +369,7 @@ def convert_to_ttnn(model, name): model_config, tt_cache_path=get_tt_cache_path(f"{model_version}"), device=t3k_mesh_device, - weights_mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + weights_mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ), convert_to_ttnn=convert_to_ttnn, ) @@ -393,7 +393,7 @@ def convert_to_ttnn(model, name): torch.full(scalar_shape, layer.self_attn.scalar), device=t3k_mesh_device, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) # TODO: Generate embeddings and attention_mask on device tt_embeddings, tt_attention_mask = tt_FalconCausalLM.model_preprocessing( diff --git a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_decoder.py b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_decoder.py index 40676143a5b..ea7aad0cc7f 100644 --- a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_decoder.py +++ b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_decoder.py @@ -21,7 +21,7 @@ ) from loguru import logger -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor PRETRAINED_MODEL_NAME = f"tiiuae/falcon-7b-instruct" @@ -142,7 +142,7 @@ def test_falcon_decoder( tt_cache_path=get_tt_cache_path(f"{model_name}"), device=mesh_device, base_file_name=get_model_prefix(), - weights_mesh_mapper=ReplicateTensorToMesh(mesh_device), + weights_mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ), ) tt_FalconDecoder_model = TtFalconDecoderLayer( diff --git a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_mlp.py
b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_mlp.py index c118f9a9b15..5edb9f55ef5 100644 --- a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_mlp.py +++ b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_mlp.py @@ -11,7 +11,7 @@ from models.demos.ttnn_falcon7b.tt.common import create_custom_preprocessor, strip_state_dict_prefix from ttnn.model_preprocessing import preprocess_model_parameters from tests.ttnn.utils_for_testing import assert_with_pcc -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor import transformers from loguru import logger @@ -88,7 +88,7 @@ def test_falcon_mlp( tt_cache_path=get_tt_cache_path(f"{model_name}"), device=mesh_device, base_file_name=get_model_prefix(), - weights_mesh_mapper=ReplicateTensorToMesh(mesh_device), + weights_mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ), ) diff --git a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_model.py b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_model.py index 31c4d04816a..8786d1d722e 100644 --- a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_model.py +++ b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_model.py @@ -23,7 +23,7 @@ ) from loguru import logger -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor PRETRAINED_MODEL_NAME = f"tiiuae/falcon-7b-instruct" @@ -157,7 +157,7 @@ def convert_to_ttnn(model, name): tt_cache_path=get_tt_cache_path(f"{model_version}"), device=mesh_device, base_file_name=get_model_prefix(), - weights_mesh_mapper=ReplicateTensorToMesh(mesh_device), + weights_mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ), convert_to_ttnn=convert_to_ttnn, ) diff --git a/models/demos/ttnn_falcon7b/tt/falcon_model.py b/models/demos/ttnn_falcon7b/tt/falcon_model.py index f27f1122947..a54d8e36242 100644 --- a/models/demos/ttnn_falcon7b/tt/falcon_model.py +++ b/models/demos/ttnn_falcon7b/tt/falcon_model.py @@ -10,7 +10,7 @@ from models.demos.ttnn_falcon7b.tt.falcon_decoder import TtFalconDecoderLayer from models.demos.ttnn_falcon7b.tt.common import create_attention_mask -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor class TtFalconModelShared: diff --git a/models/demos/ttnn_resnet/tests/resnet50_test_infra.py b/models/demos/ttnn_resnet/tests/resnet50_test_infra.py index 2866840ad8d..559525e51a6 100644 --- a/models/demos/ttnn_resnet/tests/resnet50_test_infra.py +++ b/models/demos/ttnn_resnet/tests/resnet50_test_infra.py @@ -261,7 +261,7 @@ def get_mesh_mappers(self, device): is_mesh_device = isinstance(device, ttnn.MeshDevice) if is_mesh_device: inputs_mesh_mapper = ttnn.ShardTensorToMesh(device, dim=0) - weights_mesh_mapper = None  # ttnn.ReplicateTensorToMesh(device) causes unnecessary replication/takes more time on the first pass + weights_mesh_mapper = None  # ttnn.replicate_tensor_to_mesh_mapper(device) causes unnecessary replication/takes more time on the first pass output_mesh_composer = ttnn.ConcatMeshToTensor(device, dim=0) else: inputs_mesh_mapper = None diff --git a/models/demos/wormhole/bert_tiny/demo/demo.py b/models/demos/wormhole/bert_tiny/demo/demo.py index fe403df2338..05f92393448 100644 --- a/models/demos/wormhole/bert_tiny/demo/demo.py +++
b/models/demos/wormhole/bert_tiny/demo/demo.py @@ -73,7 +73,7 @@ def run_bert_question_and_answering_inference( inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) - with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): parameters = preprocess_model_parameters( initialize_model=lambda: pytorch_model, device=mesh_device, @@ -191,7 +191,7 @@ def run_bert_question_and_answering_inference_squad_v2( inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) - with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): parameters = preprocess_model_parameters( initialize_model=lambda: pytorch_model, device=mesh_device, diff --git a/models/demos/wormhole/bert_tiny/tests/test_performance.py b/models/demos/wormhole/bert_tiny/tests/test_performance.py index bcc438ea198..5e9029ce84b 100644 --- a/models/demos/wormhole/bert_tiny/tests/test_performance.py +++ b/models/demos/wormhole/bert_tiny/tests/test_performance.py @@ -54,7 +54,7 @@ def test_perf_bert_tiny( inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) - with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): parameters = preprocess_model_parameters( initialize_model=lambda: pytorch_model, device=mesh_device, @@ -73,7 +73,7 @@ def test_perf_bert_tiny( ttnn_attention_mask = ttnn.from_torch( torch_attention_mask, dtype=ttnn.bfloat16, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), device=mesh_device, ) durations = [] diff --git a/models/demos/wormhole/distilbert/demo/demo.py b/models/demos/wormhole/distilbert/demo/demo.py index dfd89c18939..51ba895798e 100644 --- a/models/demos/wormhole/distilbert/demo/demo.py +++ b/models/demos/wormhole/distilbert/demo/demo.py @@ -51,12 +51,12 @@ def run_distilbert_question_and_answering_inference( tt_model_name = f"ttnn_{model_name}_optimized" inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) - weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + weights_mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) profiler.start(f"preprocessing_parameter") - with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): parameters = preprocess_model_parameters( model_name=tt_model_name, initialize_model=lambda: HF_model, @@ -192,10 +192,10 @@ def run_distilbert_question_and_answering_inference_squad_v2( tt_model_name = f"ttnn_{model_name}_optimized" inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) - weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + weights_mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) - with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): parameters = preprocess_model_parameters( model_name=tt_model_name, initialize_model=lambda: HF_model, diff --git a/models/demos/wormhole/distilbert/tests/test_perf_distilbert.py 
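The bert_tiny and distilbert hunks above wrap parameter preprocessing in `ttnn.distribute(...)` so that every tensor produced inside the block is replicated across the mesh. A hedged sketch of that pattern follows; `mesh_device` and `pytorch_model` are assumed to exist (a mesh handle and a host-side `torch.nn.Module`), and the helper comes from `ttnn.model_preprocessing` as in the tests above.

```python
import ttnn
from ttnn.model_preprocessing import preprocess_model_parameters

# Every tensor created inside this block uses the replicate mapper by default,
# so the preprocessed weights land on all devices in the mesh.
with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)):
    parameters = preprocess_model_parameters(
        initialize_model=lambda: pytorch_model,
        device=mesh_device,
    )
```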
b/models/demos/wormhole/distilbert/tests/test_perf_distilbert.py index a3fad4aa54c..77c0232a765 100644 --- a/models/demos/wormhole/distilbert/tests/test_perf_distilbert.py +++ b/models/demos/wormhole/distilbert/tests/test_perf_distilbert.py @@ -68,10 +68,10 @@ def test_performance_distilbert_for_qa( tt_model_name = f"ttnn_{model_name}_optimized" inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) - weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + weights_mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) profiler.start(f"preprocessing_parameter") - with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): parameters = preprocess_model_parameters( model_name=tt_model_name, initialize_model=lambda: HF_model, diff --git a/models/experimental/functional_unet/tests/test_unet_bottleneck.py b/models/experimental/functional_unet/tests/test_unet_bottleneck.py index c78de65acaf..ce94ef1ccaa 100644 --- a/models/experimental/functional_unet/tests/test_unet_bottleneck.py +++ b/models/experimental/functional_unet/tests/test_unet_bottleneck.py @@ -53,7 +53,7 @@ def test_unet_bottleneck_multi_device( pytest.skip("Test is only valid for N300") inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) - weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + weights_mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) torch_input, ttnn_input = create_unet_input_tensors(batch, groups) diff --git a/models/experimental/functional_unet/tests/test_unet_downblock.py b/models/experimental/functional_unet/tests/test_unet_downblock.py index 1ea2633b2ad..69231d68504 100644 --- a/models/experimental/functional_unet/tests/test_unet_downblock.py +++ b/models/experimental/functional_unet/tests/test_unet_downblock.py @@ -82,7 +82,7 @@ def test_unet_downblock_multi_device( pytest.skip("Test is only valid for N300") inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) - weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + weights_mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) torch_input, ttnn_input = create_unet_input_tensors(batch, groups) diff --git a/models/experimental/functional_unet/tests/test_unet_multi_device.py b/models/experimental/functional_unet/tests/test_unet_multi_device.py index f611ce3999c..c21759f72a6 100644 --- a/models/experimental/functional_unet/tests/test_unet_multi_device.py +++ b/models/experimental/functional_unet/tests/test_unet_multi_device.py @@ -29,7 +29,7 @@ def test_unet_multi_device_model(batch, groups, mesh_device, use_program_cache, pytest.skip("Test is only valid for N300 or T3000") inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) - weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + weights_mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) torch_input, ttnn_input = create_unet_input_tensors(batch, groups) diff --git a/models/experimental/functional_unet/tests/test_unet_trace.py b/models/experimental/functional_unet/tests/test_unet_trace.py index 17211134106..74045f761d5 100644 --- a/models/experimental/functional_unet/tests/test_unet_trace.py +++ b/models/experimental/functional_unet/tests/test_unet_trace.py @@ -232,7 +232,7 @@ def test_unet_trace_2cq_multi_device( 
pytest.skip("Test is only valid for N300 or T3000") inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) - weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + weights_mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) torch_input, ttnn_input = create_unet_input_tensors(batch, groups) @@ -490,7 +490,7 @@ def test_unet_trace_2cq_same_io_multi_device( pytest.skip("Test is only valid for N300 or T3000") inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) - weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + weights_mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) torch_input, ttnn_input = create_unet_input_tensors(batch, groups) diff --git a/models/experimental/functional_unet/tests/test_unet_upblock.py b/models/experimental/functional_unet/tests/test_unet_upblock.py index 9c623d1c840..15566539c22 100644 --- a/models/experimental/functional_unet/tests/test_unet_upblock.py +++ b/models/experimental/functional_unet/tests/test_unet_upblock.py @@ -99,7 +99,7 @@ def test_unet_upblock_multi_device( pytest.skip("Test is only valid for N300") inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) - weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + weights_mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) torch_input, ttnn_input = create_unet_input_tensors(batch, groups) diff --git a/models/experimental/grok/demo/demo.py b/models/experimental/grok/demo/demo.py index 1a8a477506d..b41b689e237 100644 --- a/models/experimental/grok/demo/demo.py +++ b/models/experimental/grok/demo/demo.py @@ -17,7 +17,7 @@ os.environ["WH_ARCH_YAML"] = "wormhole_b0_80_arch_eth_dispatch.yaml" import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.experimental.grok.tt.grok_common import ( prepare_inputs_ttnn, prepare_rotation_mat_ttnn, @@ -85,7 +85,7 @@ def preprocess_inputs(input_prompts, tokenizer, model_args, dtype, instruct, mes device=mesh_device, dtype=ttnn.uint32, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) for i in range(max_prompt_len) ] @@ -95,7 +95,7 @@ def preprocess_inputs(input_prompts, tokenizer, model_args, dtype, instruct, mes device=mesh_device, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) for i in range(max_prompt_len) ] diff --git a/models/experimental/grok/tests/test_grok_decoder.py b/models/experimental/grok/tests/test_grok_decoder.py index aa3b1c6ce00..b2f7618215d 100644 --- a/models/experimental/grok/tests/test_grok_decoder.py +++ b/models/experimental/grok/tests/test_grok_decoder.py @@ -18,7 +18,7 @@ from models.experimental.grok.reference.model import DecoderLayer from models.experimental.grok.tt.model_config import TtModelArgs from models.utility_functions import comp_pcc, comp_allclose -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor @pytest.mark.timeout(500 * 8) diff --git a/models/experimental/grok/tests/test_grok_mlp.py b/models/experimental/grok/tests/test_grok_mlp.py index d5a154ce8fe..adc730df7fc 100644 --- 
a/models/experimental/grok/tests/test_grok_mlp.py +++ b/models/experimental/grok/tests/test_grok_mlp.py @@ -13,7 +13,7 @@ os.environ["GROK_CACHE_PATH"] = "/mnt/MLPerf/tt_dnn-models/Grok/Grok-1/" import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.experimental.grok.tt.grok_mlp import TtGrokMLP from models.experimental.grok.reference.model import MoeMLP, RMSNorm @@ -70,7 +70,7 @@ def test_grok_mlp_inference(t3k_mesh_device, use_program_cache, reset_seeds): dtype=ttnn.bfloat16, memory_config=ttnn.L1_MEMORY_CONFIG, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) tt_output = tt_model(tt_input) diff --git a/models/experimental/grok/tests/test_grok_moe.py b/models/experimental/grok/tests/test_grok_moe.py index ee6c77e6553..c36d85528a3 100644 --- a/models/experimental/grok/tests/test_grok_moe.py +++ b/models/experimental/grok/tests/test_grok_moe.py @@ -13,7 +13,7 @@ os.environ["GROK_CACHE_PATH"] = "/mnt/MLPerf/tt_dnn-models/Grok/Grok-1/" import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.experimental.grok.tt.grok_mlp import TtGrokMLP from models.experimental.grok.tt.grok_moe import TtMoeLayer @@ -86,7 +86,7 @@ def test_grok_moe_inference(t3k_mesh_device, use_program_cache, reset_seeds): dtype=ttnn.bfloat16, memory_config=ttnn.L1_MEMORY_CONFIG, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) # Run TT model tt_out = tt_model(tt_decode_input) diff --git a/models/experimental/grok/tests/test_grok_rms_norm.py b/models/experimental/grok/tests/test_grok_rms_norm.py index 5f220b9eb2b..e080ec5bebe 100644 --- a/models/experimental/grok/tests/test_grok_rms_norm.py +++ b/models/experimental/grok/tests/test_grok_rms_norm.py @@ -13,7 +13,7 @@ os.environ["GROK_CACHE_PATH"] = "/mnt/MLPerf/tt_dnn-models/Grok/Grok-1/" import ttnn -from ttnn import ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.experimental.grok.tt.grok_rms_norm import TtRMSNorm, TtRMSNormSharded from models.experimental.grok.reference.model import RMSNorm @@ -55,7 +55,7 @@ def test_grok_rms_norm_inference(t3k_mesh_device, use_program_cache, reset_seeds device=t3k_mesh_device, dtype=dtype, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) tt_output = tt_model(tt_input) @@ -104,7 +104,7 @@ def test_grok_rms_norm_sharded_inference(t3k_mesh_device, use_program_cache, res device=t3k_mesh_device, dtype=dtype, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) tt_output = tt_model(tt_input) diff --git a/models/experimental/grok/tt/grok_attention.py b/models/experimental/grok/tt/grok_attention.py index 794c6daa784..38fbdca652f 100644 --- a/models/experimental/grok/tt/grok_attention.py +++ b/models/experimental/grok/tt/grok_attention.py @@ -5,7 +5,7 @@ import torch import ttnn from models.utility_functions import nearest_32 -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.experimental.grok.tt.grok_common 
import LightweightModule @@ -91,7 +91,7 @@ def __init__(self, mesh_device, state_dict, args, layer_num, dtype): .unsqueeze(0) .unsqueeze(0), device=self.mesh_device, - mesh_mapper=ReplicateTensorToMesh(self.mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), dtype=self.dtype, memory_config=self.model_config["ATTN_WEIGHTS_MEMCFG"], layout=self.model_config["ATTN_W_LAYOUT_TILE"], diff --git a/models/experimental/grok/tt/grok_common.py b/models/experimental/grok/tt/grok_common.py index 08181844cfa..ade352ee7d1 100644 --- a/models/experimental/grok/tt/grok_common.py +++ b/models/experimental/grok/tt/grok_common.py @@ -5,7 +5,7 @@ from loguru import logger import torch import ttnn -from ttnn import ShardTensorToMesh, ConcatMeshToTensor, ReplicateTensorToMesh +from ttnn import ShardTensorToMesh, ConcatMeshToTensor, replicate_tensor_to_mesh_mapper from models.utility_functions import nearest_32 @@ -78,7 +78,7 @@ def prepare_inputs_ttnn(x_bsh, hidden_size, current_pos, mesh_device): dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.L1_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) # Attention mask @@ -94,7 +94,7 @@ def prepare_inputs_ttnn(x_bsh, hidden_size, current_pos, mesh_device): dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) ATTN_MASK_MEMCFG = ttnn.create_sharded_memory_config( @@ -121,7 +121,7 @@ def prepare_rotation_mat_ttnn(head_dim, max_seq_len, mesh_device): device=mesh_device, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) for rot_mat_i in rot_mat ] @@ -163,7 +163,7 @@ def cache_attention(mesh_device, state_dict, model_args, rot_emb_matrix_list, se layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=ttnn.L1_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_attn = TtGrokAttention( @@ -185,7 +185,7 @@ def cache_attention(mesh_device, state_dict, model_args, rot_emb_matrix_list, se dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) ATTN_MASK_MEMCFG = ttnn.create_sharded_memory_config( diff --git a/models/experimental/grok/tt/grok_moe.py b/models/experimental/grok/tt/grok_moe.py index 82526c3292a..e40c00e08e2 100644 --- a/models/experimental/grok/tt/grok_moe.py +++ b/models/experimental/grok/tt/grok_moe.py @@ -4,7 +4,7 @@ import torch import ttnn -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh +from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper from models.experimental.grok.tt.grok_common import LightweightModule from models.experimental.grok.scripts.tlog import tlog, tlog_mesh_device @@ -34,7 +34,7 @@ def __init__(self, mesh_device, state_dict, experts, args, layer_num, dtype): memory_config=self.model_config["GATE_WEIGHTS_MEMCFG"], cache_file_name=cache_name, device=self.mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) self.num_devices = 8 @@ -48,7 +48,7 @@ def __init__(self, mesh_device, state_dict, experts, args, layer_num, dtype): dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=self.mesh_device, - 
mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) self.expert_mask_11BB = ttnn.from_torch( torch.cat([torch.full((1, 1, 32, 32), fill_value=i + 1) for i in range(8)], dim=3), @@ -64,7 +64,7 @@ def __init__(self, mesh_device, state_dict, experts, args, layer_num, dtype): dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) top2_mask = torch.full((1, 1, 32, 32), fill_value=0.0) @@ -74,7 +74,7 @@ def __init__(self, mesh_device, state_dict, experts, args, layer_num, dtype): dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) self.softmax_compute_config = ttnn.WormholeComputeKernelConfig( math_fidelity=ttnn.MathFidelity.HiFi4, math_approx_mode=False, fp32_dest_acc_en=True, packer_l1_acc=True diff --git a/models/experimental/grok/tt/grok_rms_norm.py b/models/experimental/grok/tt/grok_rms_norm.py index b337ab81381..08166210307 100644 --- a/models/experimental/grok/tt/grok_rms_norm.py +++ b/models/experimental/grok/tt/grok_rms_norm.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import torch import ttnn -from ttnn import ReplicateTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper from models.experimental.grok.tt.grok_common import LightweightModule @@ -43,7 +43,7 @@ def __init__( layout=self.model_config["NORM_W_LAYOUT_TILE"], memory_config=self.model_config["NORM_WEIGHTS_MEMCFG"], cache_file_name=cache_name, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) def forward(self, x: ttnn.Tensor) -> ttnn.Tensor: @@ -88,7 +88,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, memory_config=self.model_config["NORM_WEIGHTS_MEMCFG"], cache_file_name=cache_name, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) def forward(self, x: ttnn.Tensor, out_sharded=False) -> ttnn.Tensor: diff --git a/tech_reports/CNNs/cnn_optimizations.md b/tech_reports/CNNs/cnn_optimizations.md index 5afd41adb32..49b656ed21e 100644 --- a/tech_reports/CNNs/cnn_optimizations.md +++ b/tech_reports/CNNs/cnn_optimizations.md @@ -197,7 +197,7 @@ Throughput can be improved if multiple chips are availible by replicating the CN ```python inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) # Shard input tensor on dimension 0 across each device -weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) # Replicate weights across all devices +weights_mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) # Replicate weights across all devices output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) # Map multi-device tensor back to single host tensor ``` diff --git a/tech_reports/LLMs/llms.md b/tech_reports/LLMs/llms.md index 0342e432399..1334618d9c4 100644 --- a/tech_reports/LLMs/llms.md +++ b/tech_reports/LLMs/llms.md @@ -1215,7 +1215,7 @@ mesh_tensor_replicated = ttnn.from_torch( torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) ``` diff --git a/tech_reports/Programming_Mesh_of_Devices/Programming_Mesh_of_Devices_with_TT-NN.md b/tech_reports/Programming_Mesh_of_Devices/Programming_Mesh_of_Devices_with_TT-NN.md index 862921f5d33..760d78837ea 100644 --- 
a/tech_reports/Programming_Mesh_of_Devices/Programming_Mesh_of_Devices_with_TT-NN.md +++ b/tech_reports/Programming_Mesh_of_Devices/Programming_Mesh_of_Devices_with_TT-NN.md @@ -461,7 +461,7 @@ with ttnn.distribute(ttnn.ShardTensorToMesh(mesh_device, dim=0)): ) # Replicate model parameters to devices in the mesh -with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): +with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): parameters = ttnn.model_preprocessing.preprocess_model_parameters( initialize_model=lambda: model, device=mesh_device, @@ -539,7 +539,7 @@ mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2,4)) # Initialize input activations on all devices in the mesh # Alternatively, we can shard the input activations on the height dimension and # subsequently invoke all-gather on the height dimension to form a complete tensor per device. -with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): +with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): hidden_states = ttnn.from_torch( torch_hidden_states, dtype=ttnn.bfloat16, diff --git a/tests/ttnn/distributed/test_data_parallel_example.py b/tests/ttnn/distributed/test_data_parallel_example.py index fb5f59568c0..2af396f4d7d 100644 --- a/tests/ttnn/distributed/test_data_parallel_example.py +++ b/tests/ttnn/distributed/test_data_parallel_example.py @@ -46,7 +46,7 @@ def test_data_parallel_falcon_mlp(mesh_device): ) # Replicate model parameters to devices in the mesh - with ttnn.distribute(mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device)): + with ttnn.distribute(mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): parameters = preprocess_model_parameters( initialize_model=lambda: model, device=mesh_device, diff --git a/tests/ttnn/distributed/test_data_parallel_example_TG.py b/tests/ttnn/distributed/test_data_parallel_example_TG.py index 66b8bcacb5b..ae45f63b467 100644 --- a/tests/ttnn/distributed/test_data_parallel_example_TG.py +++ b/tests/ttnn/distributed/test_data_parallel_example_TG.py @@ -48,7 +48,7 @@ def test_data_parallel_falcon_mlp(mesh_device): ) # Replicate model parameters to devices in the mesh - with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): parameters = preprocess_model_parameters( initialize_model=lambda: model, device=mesh_device, diff --git a/tests/ttnn/distributed/test_multidevice_TG.py b/tests/ttnn/distributed/test_multidevice_TG.py index 82b4381c4aa..a98aaa9f540 100644 --- a/tests/ttnn/distributed/test_multidevice_TG.py +++ b/tests/ttnn/distributed/test_multidevice_TG.py @@ -13,7 +13,7 @@ from ttnn import ( ShardTensorToMesh, ShardTensor2dMesh, - ReplicateTensorToMesh, + ttnn.replicate_tensor_to_mesh_mapper, ConcatMeshToTensor, ConcatMesh2dToTensor, MeshToTensor, @@ -39,7 +39,7 @@ def test_galaxy_matmul_1d_fracture(mesh_device): dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) weights = ttnn.from_torch( weights_pt, @@ -362,7 +362,7 @@ def test_galaxy_eltwise_add(M, N, mesh_device): layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=LN_OUTPUT_MEMCFG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) attn_output = ttnn.from_torch( @@ -371,7 +371,7 @@ def test_galaxy_eltwise_add(M, N, mesh_device): layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=LN_OUTPUT_MEMCFG, - 
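One note on the `test_multidevice_TG.py` hunk above: it replaces the bare name `ReplicateTensorToMesh` inside a `from ttnn import (...)` block with the dotted name `ttnn.replicate_tensor_to_mesh_mapper`, which Python's import syntax does not accept. The two spellings the rest of this patch actually relies on are sketched below; `mesh_device` is assumed to come from the usual fixture, and this is a note on the intended form rather than the final state of the branch.

```python
# Option 1: import the bare name and call it directly.
from ttnn import replicate_tensor_to_mesh_mapper
mapper = replicate_tensor_to_mesh_mapper(mesh_device)

# Option 2: import the module and use the qualified name at call sites,
# as most hunks in this patch do.
import ttnn
mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device)
```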
mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) gt = residual_pt + attn_output_pt @@ -420,7 +420,7 @@ def test_galaxy_attn_matmul(M, N, head_dim, num_heads, mesh_shape, mesh_device): dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) weights = ttnn.from_torch( @@ -536,7 +536,7 @@ def test_galaxy_nlp_create_heads_decode( layout=ttnn.TILE_LAYOUT, memory_config=CREATE_HEAD_INPUT_MEMCFG, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) # tt operation @@ -636,7 +636,7 @@ def test_galaxy_rotary_matmul(batch, seq_len, head_dim, n_local_heads, n_local_k layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=ROTARY_INPUT_MEMCFG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) key_layer = ttnn.from_torch( @@ -645,7 +645,7 @@ def test_galaxy_rotary_matmul(batch, seq_len, head_dim, n_local_heads, n_local_k layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=ROTARY_INPUT_MEMCFG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) rot_mats = ttnn.from_torch( @@ -654,7 +654,7 @@ def test_galaxy_rotary_matmul(batch, seq_len, head_dim, n_local_heads, n_local_k layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=ROT_MAT_MEMCFG, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) compute_kernel_rotary = ttnn.WormholeComputeKernelConfig( @@ -725,7 +725,7 @@ def test_fill_cache( dtype=cache_dtype, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) for i in range(num_users): x = torch.randn(input_shape).bfloat16().float() @@ -753,7 +753,7 @@ def test_fill_cache( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=input_mem_config, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) cachett = ttnn.fill_cache(cachett, xt, i) @@ -794,7 +794,7 @@ def test_update_cache_decode( dtype=cache_dtype, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) x = torch.randn(input_shape).bfloat16().float() @@ -828,7 +828,7 @@ def test_update_cache_decode( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=input_mem_config, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) cachett = ttnn.update_cache(cachett, xt, cache_idx, batch_offset=batch_offset) @@ -924,7 +924,7 @@ def run_test_sdpa_decode_single_iter( dtype=dtype, layout=ttnn.TILE_LAYOUT, memory_config=dram_memcfg, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_V = ttnn.from_torch( @@ -933,7 +933,7 @@ def run_test_sdpa_decode_single_iter( dtype=dtype, layout=ttnn.TILE_LAYOUT, memory_config=dram_memcfg, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) start_idx = s // 2 scale = d**-0.5 @@ -965,7 +965,7 @@ def run_test_sdpa_decode_single_iter( dtype=dtype, layout=ttnn.TILE_LAYOUT, memory_config=dram_memcfg, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tt_back = 
ttnn.transformer.scaled_dot_product_attention_decode( @@ -1064,7 +1064,7 @@ def test_galaxy_nlp_concat_heads_decode( layout=ttnn.TILE_LAYOUT, memory_config=CONCAT_HEADS_INPUT_MEMCFG, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) concat_head_output = ttnn.experimental.nlp_concat_heads_decode( @@ -1151,7 +1151,7 @@ def test_galaxy_layernorm(M, N, mesh_device): layout=ttnn.TILE_LAYOUT, memory_config=LN_OUTPUT_MEMCFG, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) norm_weights_tt = ttnn.from_torch( @@ -1159,7 +1159,7 @@ def test_galaxy_layernorm(M, N, mesh_device): dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) norm_output = ttnn.rms_norm( diff --git a/tests/ttnn/distributed/test_tensor_parallel_example_T3000.py b/tests/ttnn/distributed/test_tensor_parallel_example_T3000.py index 87c38fc5780..72a81cd766a 100644 --- a/tests/ttnn/distributed/test_tensor_parallel_example_T3000.py +++ b/tests/ttnn/distributed/test_tensor_parallel_example_T3000.py @@ -53,7 +53,7 @@ def test_tensor_parallel_falcon_mlp(): # Initialize input activations on all devices in the mesh # Alternatively, we can shard the input activations on the height dimension and # subsequently invoke all-gather on the height dimension to form a complete tensor per device. - with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): hidden_states = ttnn.from_torch( torch_hidden_states, dtype=ttnn.bfloat16, diff --git a/tests/ttnn/integration_tests/bert_tiny/test_bert_tiny_wh.py b/tests/ttnn/integration_tests/bert_tiny/test_bert_tiny_wh.py index d309befa0b2..b4a023a1235 100644 --- a/tests/ttnn/integration_tests/bert_tiny/test_bert_tiny_wh.py +++ b/tests/ttnn/integration_tests/bert_tiny/test_bert_tiny_wh.py @@ -36,7 +36,7 @@ def test_bert_attention_inference( inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) - with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): parameters = preprocess_model_parameters( initialize_model=lambda: pytorch_attention_model, device=mesh_device, @@ -93,7 +93,7 @@ def test_bert_intermediate_inference( inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) - with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): parameters = preprocess_model_parameters( initialize_model=lambda: pytorch_intermediate_model, device=mesh_device, @@ -140,7 +140,7 @@ def test_bert_output_inference( inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) - with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): parameters = preprocess_model_parameters( initialize_model=lambda: pytorch_output_model, device=mesh_device, @@ -197,7 +197,7 @@ def test_bert_layer_inference( inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) - with 
ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): parameters = preprocess_model_parameters( initialize_model=lambda: pytorch_layer_model, device=mesh_device, @@ -247,7 +247,7 @@ def test_bert_for_question_answering(mesh_device, model_name, sequence_size, num inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) - with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): parameters = preprocess_model_parameters( initialize_model=lambda: model, device=mesh_device, @@ -277,7 +277,7 @@ def test_bert_for_question_answering(mesh_device, model_name, sequence_size, num ttnn_attention_mask = ttnn.from_torch( torch_attention_mask, dtype=ttnn.bfloat16, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), device=mesh_device, ) diff --git a/tests/ttnn/integration_tests/distilbert/test_ttnn_distilbert_wh.py b/tests/ttnn/integration_tests/distilbert/test_ttnn_distilbert_wh.py index 5d2dd6284bd..2cea87ab1d8 100644 --- a/tests/ttnn/integration_tests/distilbert/test_ttnn_distilbert_wh.py +++ b/tests/ttnn/integration_tests/distilbert/test_ttnn_distilbert_wh.py @@ -29,13 +29,13 @@ def test_distilbert_for_question_answering(mesh_device, model_name, batch_size, tt_model_name = f"ttnn_{model_name}_optimized" inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) - weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + weights_mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) if ttnn.GetNumAvailableDevices() == 2: batch_size = batch_size * 2 - with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): parameters = preprocess_model_parameters( model_name=tt_model_name, initialize_model=lambda: HF_model, diff --git a/tests/ttnn/unit_tests/operations/prefetcher_common.py b/tests/ttnn/unit_tests/operations/prefetcher_common.py index bfc881c16dc..1d91ac26f3c 100644 --- a/tests/ttnn/unit_tests/operations/prefetcher_common.py +++ b/tests/ttnn/unit_tests/operations/prefetcher_common.py @@ -8,7 +8,7 @@ import math from loguru import logger -from ttnn import ReplicateTensorToMesh, ShardTensor2dMesh, ConcatMeshToTensor, ConcatMesh2dToTensor +from ttnn import replicate_tensor_to_mesh_mapper, ShardTensor2dMesh, ConcatMeshToTensor, ConcatMesh2dToTensor from models.common.lightweightmodule import LightweightModule from tests.ttnn.utils_for_testing import assert_with_pcc from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import ( @@ -228,7 +228,7 @@ def run_prefetcher_mm( mesh_composer = None if isinstance(device, ttnn._ttnn.multi_device.MeshDevice): cluster_shape = device.shape - mesh_mapper = ReplicateTensorToMesh(device) + mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(device) mesh_composer = ConcatMesh2dToTensor(device, dims=(0, 1), mesh_shape=cluster_shape) pt_tensors = [] diff --git a/tests/ttnn/unit_tests/operations/test_new_conv2d.py b/tests/ttnn/unit_tests/operations/test_new_conv2d.py index dbc28079e16..a699975ca38 100644 --- a/tests/ttnn/unit_tests/operations/test_new_conv2d.py +++ b/tests/ttnn/unit_tests/operations/test_new_conv2d.py @@ -488,7 +488,7 @@ def test_conv_features_multi_device( output_layout=output_layout, 
has_bias=True, input_mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=0), - weight_mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + weight_mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), output_mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0), groups=groups, ) diff --git a/tests/ttnn/unit_tests/tensor/test_tensor_prealloc_and_write.py b/tests/ttnn/unit_tests/tensor/test_tensor_prealloc_and_write.py index 029da544301..0fb0572f80d 100644 --- a/tests/ttnn/unit_tests/tensor/test_tensor_prealloc_and_write.py +++ b/tests/ttnn/unit_tests/tensor/test_tensor_prealloc_and_write.py @@ -74,7 +74,7 @@ def test_tensor_preallocation_and_write_apis( input_tensor_a, dtype=in_dtype, layout=tensor_layout, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) ttnn.copy_host_to_device_tensor(tt_input_tensor_a, preallocated_tensor) readback_tensors = [ diff --git a/tests/ttnn/unit_tests/test_multi_device.py b/tests/ttnn/unit_tests/test_multi_device.py index 231fa015962..2be0ec58257 100644 --- a/tests/ttnn/unit_tests/test_multi_device.py +++ b/tests/ttnn/unit_tests/test_multi_device.py @@ -11,7 +11,7 @@ from tests.ttnn.utils_for_testing import assert_with_pcc -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor ####### @@ -182,14 +182,14 @@ def test_multi_device_check_per_device_shard(mesh_device, layout, memory_config, @pytest.mark.parametrize("layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT]) @pytest.mark.parametrize("memory_config", [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG]) def test_multi_device_replicate(mesh_device, shape, layout, memory_config): - """Test ReplicateTensorToMesh to broadcast a tensor across multiple devices""" - from ttnn import ReplicateTensorToMesh + """Test replicate_tensor_to_mesh_mapper to broadcast a tensor across multiple devices""" + from ttnn import replicate_tensor_to_mesh_mapper full_tensor = torch.rand(shape, dtype=torch.bfloat16) ttnn_tensor = ttnn.from_torch( full_tensor, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(mesh_device), layout=layout, memory_config=memory_config, device=mesh_device, @@ -320,7 +320,7 @@ def test_multi_device_data_parallel_matmul_op(mesh_device): torch_input_b_tensor, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) ttnn_output_tensor = ttnn_input_a_tensor @ ttnn_input_b_tensor @@ -360,7 +360,7 @@ def test_multi_device_as_tensor_api(mesh_device, layout, memory_config, dtype): device=mesh_device, memory_config=memory_config, cache_file_name=f"{temp_file.name}.weight", - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) ttnn_input_b_tensor = ttnn.as_tensor( @@ -370,7 +370,7 @@ def test_multi_device_as_tensor_api(mesh_device, layout, memory_config, dtype): device=mesh_device, memory_config=memory_config, cache_file_name=f"{temp_file.name}.weight", - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) ttnn_output_tensor = ttnn_input_a_tensor @ ttnn_input_b_tensor @@ -457,7 +457,7 @@ def test_max(mesh_device): dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + 
mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) gate_logits_1SB8 = ttnn.to_device(gate_logits_1SB8, mesh_device) weights_ex0_1SB1 = ttnn.max(gate_logits_1SB8, dim=3) @@ -499,14 +499,14 @@ def test_sharded_matmul(t3k_mesh_device): dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=t3k_mesh_device, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + mesh_mapperttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) keys_1BDP = ttnn.from_torch( torch.randn(1, 32, 128, 32), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=t3k_mesh_device, - mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device), + mesh_mapperttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device), ) q_heads_1B4D = ttnn.to_device(q_heads_1B4D, t3k_mesh_device) @@ -561,7 +561,7 @@ def test_4b_tensor(mesh_device): dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tensor = ttnn.to_device(tensor, mesh_device) x = ttnn.from_torch( @@ -569,7 +569,7 @@ def test_4b_tensor(mesh_device): dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) x = ttnn.to_device(x, mesh_device) tensor = ttnn.matmul( @@ -588,7 +588,7 @@ def test_slicing(mesh_device): dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) tensor = ttnn.to_device(tensor, mesh_device) tensor = tensor[:, :, :, :1] @@ -601,7 +601,7 @@ def test_clone(mesh_device): dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) results_11BH = ttnn.to_device(results_11BH, mesh_device) results_11BH = ttnn.clone(results_11BH, dtype=ttnn.bfloat8_b, memory_config=ttnn.L1_MEMORY_CONFIG) @@ -643,7 +643,7 @@ def test_validate_as_tensor(tmp_path, mesh_device, height, width): layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=memory_config, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), cache_file_name=tmp_path / "cache_file", ) assert tensor.dtype == ttnn.float32 @@ -657,7 +657,7 @@ def test_validate_as_tensor(tmp_path, mesh_device, height, width): layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=memory_config, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), cache_file_name=tmp_path / "cache_file", ) assert tensor.dtype == ttnn.float32 @@ -724,7 +724,7 @@ def test_line_all_gather_after_reshape(mesh_device): def test_distribute_api(mesh_device): torch_hidden_states = torch.rand((1, 1, 32, 32), dtype=torch.bfloat16) - with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): + with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): hidden_states = ttnn.from_torch( torch_hidden_states, dtype=ttnn.bfloat8_b, diff --git a/tests/ttnn/unit_tests/test_multi_device_async.py b/tests/ttnn/unit_tests/test_multi_device_async.py index 3b1e75f500d..86baa78f1ef 100644 --- a/tests/ttnn/unit_tests/test_multi_device_async.py +++ b/tests/ttnn/unit_tests/test_multi_device_async.py @@ -83,8 +83,8 @@ def test_multi_device_check_per_device_shard(pcie_mesh_device, layout, memory_co 
@pytest.mark.parametrize("layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT]) @pytest.mark.parametrize("memory_config", [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG]) def test_multi_device_replicate(pcie_mesh_device, shape, layout, memory_config): - """Test ReplicateTensorToMesh to broadcast a tensor across multiple devices""" - from ttnn import ReplicateTensorToMesh + """Test ttnn.replicate_tensor_to_mesh_mapper to broadcast a tensor across multiple devices""" + from ttnn import replicate_tensor_to_mesh_mapper pcie_mesh_device.enable_async(True) @@ -93,7 +93,7 @@ def test_multi_device_replicate(pcie_mesh_device, shape, layout, memory_config): ttnn_tensor = ttnn.from_torch( full_tensor, - mesh_mapper=ReplicateTensorToMesh(pcie_mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(pcie_mesh_device), layout=layout, memory_config=memory_config, device=pcie_mesh_device, @@ -184,7 +184,7 @@ def test_multi_device_unary_binary_op_chain(pcie_mesh_device, program_cache, sha @pytest.mark.parametrize("input_a_shape", [(4, 1, 512, 512), (16, 1, 512, 512)]) def test_multi_device_data_parallel_op_chain(pcie_mesh_device, program_cache, input_a_shape): """Multidevice API: Running data-parallel chain of ops with matmul""" - from ttnn import ShardTensorToMesh, ConcatMeshToTensor, ReplicateTensorToMesh + from ttnn import ShardTensorToMesh, ConcatMeshToTensor, replicate_tensor_to_mesh_mapper pcie_mesh_device.enable_async(True) if program_cache: @@ -212,7 +212,7 @@ def test_multi_device_data_parallel_op_chain(pcie_mesh_device, program_cache, in torch_input_b_tensor, layout=ttnn.TILE_LAYOUT, device=pcie_mesh_device, - mesh_mapper=ReplicateTensorToMesh(pcie_mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(pcie_mesh_device), ) ttnn_output_tensor = ttnn.from_device( ttnn.mish( @@ -249,7 +249,7 @@ def test_multi_device_argmax(pcie_mesh_device, layout, mem_config): layout=layout, device=pcie_mesh_device, memory_config=mem_config, - mesh_mapper=ttnn.ReplicateTensorToMesh(pcie_mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(pcie_mesh_device), ) tt_out_11BH = ttnn.argmax(tt_out_11BH, dim=-1) @@ -264,7 +264,7 @@ def test_multi_device_argmax(pcie_mesh_device, layout, mem_config): @pytest.mark.parametrize("pcie_mesh_device", [2], indirect=True) def test_multi_device_explicit_dealloc(pcie_mesh_device): """Multidevice API: Ensure that deallocating multi-device tensors works as expected""" - from ttnn import ShardTensorToMesh, ConcatMeshToTensor, ReplicateTensorToMesh + from ttnn import ShardTensorToMesh, ConcatMeshToTensor, replicate_tensor_to_mesh_mapper if pcie_mesh_device.get_num_devices() <= 1: pytest.skip("Requires multiple devices to run") @@ -284,7 +284,7 @@ def test_multi_device_explicit_dealloc(pcie_mesh_device): torch_input_b_tensor, layout=ttnn.TILE_LAYOUT, device=pcie_mesh_device, - mesh_mapper=ReplicateTensorToMesh(pcie_mesh_device), + ttnn.replicate_tensor_to_mesh_mapper(pcie_mesh_device), ) ttnn_output_tensor_1 = ttnn_input_a_tensor @ ttnn_input_b_tensor ttnn_output_tensor_2 = ttnn.gelu(ttnn_output_tensor_1) @@ -315,7 +315,7 @@ def test_add_1D_tensor_and_scalar(pcie_mesh_device, scalar, size): torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=pcie_mesh_device, - mesh_mapper=ttnn.ReplicateTensorToMesh(pcie_mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(pcie_mesh_device), ) output_tensor = input_tensor + scalar output_tensors = [ttnn.to_torch(shard) for shard in ttnn.get_device_tensors(output_tensor.cpu())] diff --git a/tests/ttnn/unit_tests/test_multi_device_events.py 
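The `test_multi_device_replicate` cases above check that a replicated tensor round-trips unchanged on every device. A minimal sketch of that check, using only calls that already appear in these tests (`ttnn.from_torch`, `ttnn.get_device_tensors`, `ttnn.to_torch`); the mesh handle is assumed to come from the usual fixture and the tolerance is illustrative.

```python
import torch
import ttnn

full_tensor = torch.rand((1, 1, 32, 128), dtype=torch.bfloat16)

ttnn_tensor = ttnn.from_torch(
    full_tensor,
    layout=ttnn.TILE_LAYOUT,
    device=mesh_device,
    mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device),
)

# Every device should hold an identical copy of the original tensor.
for shard in ttnn.get_device_tensors(ttnn_tensor.cpu()):
    assert torch.allclose(ttnn.to_torch(shard), full_tensor, atol=1e-2)
```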
b/tests/ttnn/unit_tests/test_multi_device_events.py index b41c7cfaa3d..2a162932532 100644 --- a/tests/ttnn/unit_tests/test_multi_device_events.py +++ b/tests/ttnn/unit_tests/test_multi_device_events.py @@ -10,7 +10,7 @@ from loguru import logger import os from tests.ttnn.utils_for_testing import assert_with_pcc -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor @pytest.mark.parametrize("shape", [(1, 1, 512, 512)]) diff --git a/tests/ttnn/unit_tests/test_multi_device_trace.py b/tests/ttnn/unit_tests/test_multi_device_trace.py index 284a75c0a60..5ea92bf15d0 100644 --- a/tests/ttnn/unit_tests/test_multi_device_trace.py +++ b/tests/ttnn/unit_tests/test_multi_device_trace.py @@ -10,7 +10,7 @@ from loguru import logger import os from tests.ttnn.utils_for_testing import assert_with_pcc -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor NUM_TRACE_LOOPS = int(os.getenv("NUM_TRACE_LOOPS", 15)) @@ -246,7 +246,7 @@ def event_sync(device, record_cq, wait_cq): torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=0) ) ttnn_weight = ttnn.from_torch( - torch_weight, layout=ttnn.TILE_LAYOUT, mesh_mapper=ReplicateTensorToMesh(t3k_mesh_device) + torch_weight, layout=ttnn.TILE_LAYOUT, ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device) ) # Copy TTNN host tensors into preallocated Mult-Device tensors diff --git a/tests/ttnn/unit_tests/test_multi_device_trace_TG.py b/tests/ttnn/unit_tests/test_multi_device_trace_TG.py index 5c24ab237aa..bf0c10fae83 100644 --- a/tests/ttnn/unit_tests/test_multi_device_trace_TG.py +++ b/tests/ttnn/unit_tests/test_multi_device_trace_TG.py @@ -10,7 +10,7 @@ from loguru import logger import os from tests.ttnn.utils_for_testing import assert_with_pcc -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor NUM_TRACE_LOOPS = int(os.getenv("NUM_TRACE_LOOPS", 15)) @@ -224,7 +224,7 @@ def event_sync(device, record_cq, wait_cq): torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(mesh_device, dim=0) ) ttnn_weight = ttnn.from_torch( - torch_weight, layout=ttnn.TILE_LAYOUT, mesh_mapper=ReplicateTensorToMesh(mesh_device) + torch_weight, layout=ttnn.TILE_LAYOUT, ttnn.replicate_tensor_to_mesh_mapper(mesh_device) ) # Copy TTNN host tensors into preallocated Mult-Device tensors diff --git a/tests/ttnn/unit_tests/test_multi_device_trace_tgg.py b/tests/ttnn/unit_tests/test_multi_device_trace_tgg.py index d1354f329ea..0d83a2cbd08 100644 --- a/tests/ttnn/unit_tests/test_multi_device_trace_tgg.py +++ b/tests/ttnn/unit_tests/test_multi_device_trace_tgg.py @@ -10,7 +10,7 @@ from loguru import logger import os from tests.ttnn.utils_for_testing import assert_with_pcc -from ttnn import ShardTensorToMesh, ReplicateTensorToMesh, ConcatMeshToTensor +from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor NUM_TRACE_LOOPS = int(os.getenv("NUM_TRACE_LOOPS", 15)) @@ -223,7 +223,7 @@ def event_sync(device, record_cq, wait_cq): torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(mesh_device, dim=0) ) ttnn_weight = ttnn.from_torch( - torch_weight, layout=ttnn.TILE_LAYOUT, mesh_mapper=ReplicateTensorToMesh(mesh_device) + torch_weight, 
layout=ttnn.TILE_LAYOUT, ttnn.replicate_tensor_to_mesh_mapper(mesh_device) ) # Copy TTNN host tensors into preallocated Mult-Device tensors diff --git a/tests/ttnn/unit_tests/test_reshape.py b/tests/ttnn/unit_tests/test_reshape.py index 40fd7c15052..8d1d87cfd35 100644 --- a/tests/ttnn/unit_tests/test_reshape.py +++ b/tests/ttnn/unit_tests/test_reshape.py @@ -546,7 +546,7 @@ def test_reshape_zero_element(input_shape, output_shape, layout, ttnn_reshape, u ) def test_reshape_replicated_tensor(mesh_device, input_shape, output_shape): torch_input_tensor = torch.randn(input_shape) - mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) tt_input_tensor = ttnn.from_torch( torch_input_tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, mesh_mapper=mesh_mapper, device=mesh_device ) From 5244781d0226e94f88d57417f3f75223e9ba3e85 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Thu, 6 Mar 2025 16:53:19 +0000 Subject: [PATCH 70/76] switch shardtensortomesh --- .../falcon7b_common/tests/test_falcon_mlp.py | 4 +-- .../demos/falcon7b_common/tests/test_utils.py | 20 +++++------ .../demos/falcon7b_common/tt/falcon_model.py | 6 ++-- .../multimodal/test_llama_cross_attention.py | 2 +- ..._llama_cross_attention_transformer_text.py | 2 +- .../multimodal/test_llama_cross_block.py | 4 +-- models/demos/llama3/tt/llama_attention.py | 6 ++-- models/demos/llama3/tt/lm_head.py | 2 +- .../tt/multimodal/llama_conv2d_patch.py | 2 +- .../tt/multimodal/llama_cross_attention.py | 8 ++--- .../llama_cross_attention_transformer_text.py | 4 +-- ...lama_cross_attention_transformer_vision.py | 2 +- .../llama3/tt/multimodal/llama_cross_block.py | 4 +-- .../tt/multimodal/llama_image_attention.py | 4 +-- .../llama3/tt/multimodal/llama_image_mlp.py | 2 +- .../multimodal/llama_positional_embedding.py | 2 +- .../llama_tile_position_embedding.py | 2 +- .../tt/multimodal/llama_vision_model.py | 2 +- models/demos/qwen/tests/test_qwen_rms_norm.py | 2 +- models/demos/qwen/tt/lm_head.py | 2 +- models/demos/qwen/tt/model_config.py | 4 +-- models/demos/qwen/tt/qwen_attention.py | 14 ++++---- models/demos/qwen/tt/qwen_embedding.py | 2 +- models/demos/qwen/tt/qwen_mlp.py | 2 +- models/demos/t3000/falcon40b/demo/demo.py | 4 +-- .../falcon40b/tests/test_falcon_attention.py | 14 ++++---- .../falcon40b/tests/test_falcon_causallm.py | 10 +++--- .../falcon40b/tests/test_falcon_decoder.py | 18 +++++----- .../falcon40b/tests/test_falcon_model.py | 10 +++--- .../tests/test_falcon_prefill_determinism.py | 6 ++-- .../t3000/falcon40b/tt/falcon_attention.py | 6 ++-- .../t3000/falcon40b/tt/falcon_causallm.py | 4 +-- .../t3000/falcon40b/tt/falcon_embeddings.py | 4 +-- models/demos/t3000/falcon40b/tt/falcon_mlp.py | 6 ++-- .../demos/t3000/falcon40b/tt/falcon_model.py | 12 +++---- .../llama2_70b/tests/test_llama_decoder.py | 4 +-- .../llama2_70b/tests/test_llama_generation.py | 2 +- .../tt/llama_attention_optimized.py | 8 ++--- .../llama2_70b/tt/llama_decoder_optimized.py | 8 ++--- .../t3000/llama2_70b/tt/llama_embedding.py | 4 +-- .../llama2_70b/tt/llama_mlp_optimized.py | 8 ++--- .../llama2_70b/tt/llama_model_optimized.py | 8 ++--- .../t3000/mixtral8x7b/tt/mixtral_attention.py | 8 ++--- .../demos/t3000/mixtral8x7b/tt/mixtral_mlp.py | 4 +-- .../demos/t3000/mixtral8x7b/tt/mixtral_moe.py | 6 ++-- .../tests/multi_chip/test_falcon_attention.py | 8 ++--- .../tests/multi_chip/test_falcon_causallm.py | 10 +++--- .../tests/multi_chip/test_falcon_decoder.py | 8 ++--- 
.../tests/multi_chip/test_falcon_mlp.py | 6 ++-- .../tests/multi_chip/test_falcon_model.py | 6 ++-- models/demos/ttnn_falcon7b/tt/falcon_model.py | 4 +-- .../ttnn_resnet/tests/resnet50_test_infra.py | 2 +- models/demos/wormhole/bert_tiny/demo/demo.py | 4 +-- .../bert_tiny/tests/test_performance.py | 2 +- models/demos/wormhole/distilbert/demo/demo.py | 4 +-- .../distilbert/tests/test_perf_distilbert.py | 2 +- .../tests/test_unet_bottleneck.py | 2 +- .../tests/test_unet_downblock.py | 2 +- .../tests/test_unet_multi_device.py | 2 +- .../functional_unet/tests/test_unet_trace.py | 4 +-- .../tests/test_unet_upblock.py | 2 +- models/experimental/grok/tt/grok_attention.py | 6 ++-- models/experimental/grok/tt/grok_common.py | 2 +- models/experimental/grok/tt/grok_mlp.py | 4 +-- models/experimental/grok/tt/grok_model.py | 2 +- models/experimental/grok/tt/grok_moe.py | 4 +-- tech_reports/CNNs/cnn_optimizations.md | 2 +- tech_reports/LLMs/llms.md | 4 +-- .../Programming_Mesh_of_Devices_with_TT-NN.md | 10 +++--- .../sweeps/ccl/line_all_gather.py | 4 +-- .../distributed/test_data_parallel_example.py | 2 +- .../test_data_parallel_example_TG.py | 2 +- tests/ttnn/distributed/test_multidevice_TG.py | 4 +-- .../test_tensor_parallel_example_T3000.py | 2 +- .../bert_tiny/test_bert_tiny_wh.py | 10 +++--- .../distilbert/test_ttnn_distilbert_wh.py | 2 +- .../operations/ccl/test_all_gather.py | 2 +- .../ccl/test_all_gather_llama_perf_sweep.py | 2 +- .../operations/ccl/test_all_gather_matmul.py | 4 +-- .../operations/ccl/test_all_gather_nightly.py | 4 +-- .../ccl/test_barrier_t3000_frequent.py | 5 ++- .../ccl/test_reduce_scatter_post_commit.py | 2 +- .../unit_tests/operations/test_new_conv2d.py | 2 +- tests/ttnn/unit_tests/test_multi_device.py | 34 +++++++++---------- .../unit_tests/test_multi_device_async.py | 32 ++++++++--------- .../unit_tests/test_multi_device_events.py | 6 ++-- .../unit_tests/test_multi_device_trace.py | 10 +++--- .../unit_tests/test_multi_device_trace_TG.py | 10 +++--- .../unit_tests/test_multi_device_trace_tgg.py | 10 +++--- tests/ttnn/unit_tests/test_sub_device.py | 2 +- ttnn/ttnn/distributed/distributed.py | 4 +-- 91 files changed, 257 insertions(+), 258 deletions(-) diff --git a/models/demos/falcon7b_common/tests/test_falcon_mlp.py b/models/demos/falcon7b_common/tests/test_falcon_mlp.py index cf741ff67d1..6e8f2328eef 100644 --- a/models/demos/falcon7b_common/tests/test_falcon_mlp.py +++ b/models/demos/falcon7b_common/tests/test_falcon_mlp.py @@ -6,7 +6,7 @@ import torch from loguru import logger import ttnn -from ttnn import ShardTensorToMesh +from ttnn import shard_tensor_to_mesh_mapper from models.demos.falcon7b_common.tt.falcon_mlp import TtFalconMLPDecode, TtFalconMLPPrefill from models.demos.falcon7b_common.tt.model_config import get_model_config from models.demos.falcon7b_common.tests.test_utils import load_hf_model, tt_from_torch, get_num_devices @@ -79,7 +79,7 @@ def run_test_FalconMLP_inference( dtype=model_config["DEFAULT_DTYPE"], device=mesh_device, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) tt_out = tt_FalconMLP_model(tt_mlp_input) diff --git a/models/demos/falcon7b_common/tests/test_utils.py b/models/demos/falcon7b_common/tests/test_utils.py index 076d64500e6..b8bf2caa254 100644 --- a/models/demos/falcon7b_common/tests/test_utils.py +++ b/models/demos/falcon7b_common/tests/test_utils.py @@ -4,7 +4,7 @@ import torch import ttnn -from ttnn import ShardTensorToMesh, 
replicate_tensor_to_mesh_mapper +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper from transformers import FalconForCausalLM from models.utility_functions import tt_tensors_to_torch_tensors @@ -106,7 +106,7 @@ def get_rand_falcon_inputs( dtype=model_config["DEFAULT_DTYPE"], device=mesh_device, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) if model_config["PREFILL_OPTIMIZED_MODE"] and seq_len in [2048, 128, 1024]: @@ -121,7 +121,7 @@ def get_rand_falcon_inputs( dtype=ttnn.bfloat4_b, device=mesh_device, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) for attn_mask in attn_masks ] @@ -131,7 +131,7 @@ def get_rand_falcon_inputs( dtype=model_config["DEFAULT_DTYPE"], device=mesh_device, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) # Generate kvcache for each layer @@ -145,14 +145,14 @@ def get_rand_falcon_inputs( dtype=model_config["DEFAULT_DTYPE"], device=mesh_device, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) tt_v_cache = tt_from_torch( tt_v_cache.unsqueeze(1), dtype=model_config["DEFAULT_DTYPE"], device=mesh_device, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) tt_layer_past += ((tt_k_cache, tt_v_cache),) @@ -169,7 +169,7 @@ def get_rand_falcon_inputs( dtype=model_config["DEFAULT_DTYPE"], device=mesh_device, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=2), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=2), ) attention_mask_bool = torch.zeros(global_batch, 1, q_len, kv_len, dtype=bool) @@ -200,7 +200,7 @@ def get_rand_falcon_inputs( device=mesh_device, layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=model_config["ATTN_MASK_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=device_shard_dim), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=device_shard_dim), ) if not model_config["l1_sharded"]: # Tilize attn masks @@ -227,14 +227,14 @@ def get_rand_falcon_inputs( dtype=model_config["DEFAULT_DTYPE"], device=mesh_device, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) tt_v_cache = tt_from_torch( tt_v_cache.unsqueeze(1), dtype=model_config["DEFAULT_DTYPE"], device=mesh_device, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) tt_layer_past += ((tt_k_cache, tt_v_cache),) diff --git a/models/demos/falcon7b_common/tt/falcon_model.py b/models/demos/falcon7b_common/tt/falcon_model.py index b8c9a50423b..83b2a51c278 100644 --- a/models/demos/falcon7b_common/tt/falcon_model.py +++ b/models/demos/falcon7b_common/tt/falcon_model.py @@ -7,7 +7,7 @@ import torch import ttnn -from ttnn import replicate_tensor_to_mesh_mapper, ShardTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper, shard_tensor_to_mesh_mapper from models.demos.falcon7b_common.tt.falcon_decoder import TtFalconDecoderLayer from models.demos.falcon7b_common.tt.model_utils import get_weights_cached, layernorm @@ -177,7 +177,7 @@ def 
model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=self.model_config["INPUT_MEMCFG"], - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=0), ) elif llm_mode == "decode": assert batch_size % 32 == 0, "For decode, batch_size must be multiple of 32!" @@ -226,7 +226,7 @@ def model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=self.model_config["INPUT_MEMCFG"], - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=1), ) else: raise NotImplementedError(f"Llm mode {llm_mode} is not supported! Must be one of prefill or decode.") diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py index d4eedaaf744..7bdd8059769 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention.py @@ -106,7 +106,7 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=ttnn.bfloat16, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) for _ in range(2) ] diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py index 16b85b2b220..eec8b4f7bd1 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py @@ -209,7 +209,7 @@ def test_llama_cross_attention_transformer_text_inference( dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=-1), ) rot_mats = get_prefill_rot_mat( diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py index 0d132fb5191..ff8d79180c0 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_block.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_block.py @@ -97,7 +97,7 @@ def test_llama_cross_attention_transformer_block_inference( layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=ttnn.bfloat16, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) for _ in range(2) ] @@ -177,7 +177,7 @@ def test_llama_cross_attention_transformer_block_inference( dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=-1), ) tt_out = tt_model( tt_tensor_x, diff --git a/models/demos/llama3/tt/llama_attention.py b/models/demos/llama3/tt/llama_attention.py index 9afd6c45738..d9c064c2ddc 100644 --- a/models/demos/llama3/tt/llama_attention.py +++ b/models/demos/llama3/tt/llama_attention.py @@ -63,7 +63,7 @@ def __init__( dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, device=self.mesh_device, -
mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=1), ) user_selection_matrix = torch.eye(8, 8) user_selection_matrix = torch.nn.functional.pad(user_selection_matrix, (0, 24), "constant", 0) # (8, 32) @@ -128,7 +128,7 @@ def __init__( self.wqkv_bias_prefill = ttnn.as_tensor( qkv_bias, device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-1), dtype=self.dtype, memory_config=ttnn.DRAM_MEMORY_CONFIG, layout=ttnn.TILE_LAYOUT, @@ -153,7 +153,7 @@ def __init__( bias_tensor = ttnn.as_tensor( qkv_bias_decode, device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-1), dtype=self.dtype, memory_config=ttnn.DRAM_MEMORY_CONFIG, layout=ttnn.TILE_LAYOUT, diff --git a/models/demos/llama3/tt/lm_head.py b/models/demos/llama3/tt/lm_head.py index a79f8856e66..628ca3e093d 100644 --- a/models/demos/llama3/tt/lm_head.py +++ b/models/demos/llama3/tt/lm_head.py @@ -87,7 +87,7 @@ def __init__( ttnn.as_tensor( combined_split, device=mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=-1), layout=ttnn.TILE_LAYOUT, dtype=dtype, memory_config=memory_config, diff --git a/models/demos/llama3/tt/multimodal/llama_conv2d_patch.py b/models/demos/llama3/tt/multimodal/llama_conv2d_patch.py index 2da85f97e33..8017b9174ce 100644 --- a/models/demos/llama3/tt/multimodal/llama_conv2d_patch.py +++ b/models/demos/llama3/tt/multimodal/llama_conv2d_patch.py @@ -11,7 +11,7 @@ ) from models.common.lightweightmodule import LightweightModule -from ttnn import ShardTensorToMesh, ConcatMeshToTensor, ttnn.replicate_tensor_to_mesh_mapper +from ttnn import shard_tensor_to_mesh_mapper, ConcatMeshToTensor, replicate_tensor_to_mesh_mapper class TtLlamaConv2dPatch(LightweightModule): diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention.py b/models/demos/llama3/tt/multimodal/llama_cross_attention.py index ef312334bcf..e2a8164c695 100644 --- a/models/demos/llama3/tt/multimodal/llama_cross_attention.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_attention.py @@ -72,7 +72,7 @@ def __init__( self.wq = ttnn.as_tensor( self.state_dict[wq_str].transpose(-2, -1), device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-1), dtype=self.dtype, memory_config=ttnn.DRAM_MEMORY_CONFIG, layout=ttnn.TILE_LAYOUT, @@ -82,7 +82,7 @@ def __init__( self.wk = ttnn.as_tensor( self.state_dict[wk_str].transpose(-2, -1), device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-1), dtype=self.dtype, memory_config=ttnn.DRAM_MEMORY_CONFIG, layout=ttnn.TILE_LAYOUT, @@ -92,7 +92,7 @@ def __init__( self.wv = ttnn.as_tensor( self.state_dict[wv_str].transpose(-2, -1), device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-1), dtype=self.dtype, memory_config=ttnn.DRAM_MEMORY_CONFIG, layout=ttnn.TILE_LAYOUT, @@ -102,7 +102,7 @@ def __init__( self.wo = ttnn.as_tensor( self.state_dict[wo_str].transpose(-2, -1), device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device,
dim=-2), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-2), memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=self.dtype, layout=ttnn.TILE_LAYOUT, diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py index 28ee6e810ed..65248a2d619 100644 --- a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py @@ -93,7 +93,7 @@ def __init__( lm_head_torch[split], dtype=type, device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=dim), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=dim), layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, cache_file_name=cache_name(name, suffix, split), @@ -254,7 +254,7 @@ def setup_cache(self, max_batch_size): layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=ttnn.bfloat16, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=1), ) for _ in range(2) ] diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_vision.py b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_vision.py index f06014218fb..f96d39bb3d9 100644 --- a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_vision.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_vision.py @@ -74,7 +74,7 @@ def shuffle_weight(weight): dtype=type, device=self.mesh_device, mesh_mapper=( - ttnn.ShardTensorToMesh(self.mesh_device, dim=dim) + ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=dim) if dim is not None else ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device) ), diff --git a/models/demos/llama3/tt/multimodal/llama_cross_block.py b/models/demos/llama3/tt/multimodal/llama_cross_block.py index e09ae041595..5d7ad4620b7 100644 --- a/models/demos/llama3/tt/multimodal/llama_cross_block.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_block.py @@ -74,7 +74,7 @@ def __init__( state_dict[f"{state_dict_prefix}gate_attn"].unsqueeze(0).expand(1, self.hidden_size), dtype=ttnn.bfloat16, device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=-1), layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, ) @@ -109,7 +109,7 @@ def __init__( state_dict[f"{state_dict_prefix}gate_ffwd"].unsqueeze(0).expand(1, self.hidden_size), dtype=ttnn.bfloat16, device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=-1), layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, ) diff --git a/models/demos/llama3/tt/multimodal/llama_image_attention.py b/models/demos/llama3/tt/multimodal/llama_image_attention.py index c518793f83e..6721b100732 100644 --- a/models/demos/llama3/tt/multimodal/llama_image_attention.py +++ b/models/demos/llama3/tt/multimodal/llama_image_attention.py @@ -118,7 +118,7 @@ def pad_head_dim(weight, heads_out=True): dim=-1, ), device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-1), dtype=self.dtype, memory_config=ttnn.DRAM_MEMORY_CONFIG, layout=ttnn.TILE_LAYOUT, @@ -132,7 +132,7 @@ def pad_head_dim(weight, heads_out=True): -1, ), 
device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-2), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-2), memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=self.dtype, layout=ttnn.TILE_LAYOUT, diff --git a/models/demos/llama3/tt/multimodal/llama_image_mlp.py b/models/demos/llama3/tt/multimodal/llama_image_mlp.py index 2c085b834d4..212e558c8f7 100644 --- a/models/demos/llama3/tt/multimodal/llama_image_mlp.py +++ b/models/demos/llama3/tt/multimodal/llama_image_mlp.py @@ -41,7 +41,7 @@ def __init__( dtype=type, device=self.mesh_device, mesh_mapper=( - ttnn.ShardTensorToMesh(self.mesh_device, dim=dim) + ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=dim) if dim is not None else ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device) ), diff --git a/models/demos/llama3/tt/multimodal/llama_positional_embedding.py b/models/demos/llama3/tt/multimodal/llama_positional_embedding.py index 58aab0c4157..582904585ed 100644 --- a/models/demos/llama3/tt/multimodal/llama_positional_embedding.py +++ b/models/demos/llama3/tt/multimodal/llama_positional_embedding.py @@ -13,7 +13,7 @@ ) from models.common.lightweightmodule import LightweightModule -from ttnn import ShardTensorToMesh, ConcatMeshToTensor, ttnn.replicate_tensor_to_mesh_mapper +from ttnn import shard_tensor_to_mesh_mapper, ConcatMeshToTensor, replicate_tensor_to_mesh_mapper TILE_SIZE = 32 diff --git a/models/demos/llama3/tt/multimodal/llama_tile_position_embedding.py b/models/demos/llama3/tt/multimodal/llama_tile_position_embedding.py index 8a1a9d44064..1b3171164d5 100644 --- a/models/demos/llama3/tt/multimodal/llama_tile_position_embedding.py +++ b/models/demos/llama3/tt/multimodal/llama_tile_position_embedding.py @@ -13,7 +13,7 @@ ) from models.common.lightweightmodule import LightweightModule -from ttnn import ShardTensorToMesh, ConcatMeshToTensor, ttnn.replicate_tensor_to_mesh_mapper +from ttnn import shard_tensor_to_mesh_mapper, ConcatMeshToTensor, replicate_tensor_to_mesh_mapper class TtLlamaTilePositionEmbedding(LightweightModule): diff --git a/models/demos/llama3/tt/multimodal/llama_vision_model.py b/models/demos/llama3/tt/multimodal/llama_vision_model.py index 7b0096d1b3c..a1632100a07 100644 --- a/models/demos/llama3/tt/multimodal/llama_vision_model.py +++ b/models/demos/llama3/tt/multimodal/llama_vision_model.py @@ -345,7 +345,7 @@ def prepare_inputs_prefill( dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-1), ) if isinstance(cross_page_table, torch.Tensor): diff --git a/models/demos/qwen/tests/test_qwen_rms_norm.py b/models/demos/qwen/tests/test_qwen_rms_norm.py index e5e482e7e04..a7c64249f05 100644 --- a/models/demos/qwen/tests/test_qwen_rms_norm.py +++ b/models/demos/qwen/tests/test_qwen_rms_norm.py @@ -74,7 +74,7 @@ def test_qwen_rms_norm_inference(mesh_device, use_program_cache, reset_seeds, en device=mesh_device, dtype=dtype, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=-1), memory_config=ttnn.L1_MEMORY_CONFIG if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG, ) diff --git a/models/demos/qwen/tt/lm_head.py b/models/demos/qwen/tt/lm_head.py index 84bfb0043a1..9348c137290 100644 --- a/models/demos/qwen/tt/lm_head.py +++ b/models/demos/qwen/tt/lm_head.py @@ -63,7 +63,7 @@ def __init__(
ttnn.as_tensor( combined_split, device=mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=-1), layout=ttnn.TILE_LAYOUT, dtype=dtype, memory_config=memory_config, diff --git a/models/demos/qwen/tt/model_config.py b/models/demos/qwen/tt/model_config.py index 09f5fbf98fb..102e537685d 100644 --- a/models/demos/qwen/tt/model_config.py +++ b/models/demos/qwen/tt/model_config.py @@ -505,7 +505,7 @@ def prepare_inputs_ttnn_decode(self, x, input_mem_cfg, force_replicated=False): mesh_mapper = ( ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device) if force_replicated - else ttnn.ShardTensorToMesh(self.mesh_device, dim=-1) + else ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-1) ) if len(x.shape) == 3: @@ -561,7 +561,7 @@ def prepare_inputs_ttnn_prefill(self, x_bsh, force_replicated=False): mesh_mapper = ( ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device) if force_replicated - else ttnn.ShardTensorToMesh(self.mesh_device, dim=-1) + else ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-1) ) # input goes to DRAM diff --git a/models/demos/qwen/tt/qwen_attention.py b/models/demos/qwen/tt/qwen_attention.py index 6ef253cf8a4..d3e56f1f921 100644 --- a/models/demos/qwen/tt/qwen_attention.py +++ b/models/demos/qwen/tt/qwen_attention.py @@ -25,7 +25,7 @@ def fall_back_rope(xq, xk, rot_mats, mesh_device): xq = ttnn.from_torch( xq, device=mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=-1), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, @@ -33,7 +33,7 @@ def fall_back_rope(xq, xk, rot_mats, mesh_device): xk = ttnn.from_torch( xk, device=mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=-1), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, @@ -135,7 +135,7 @@ def __init__( dim=-1, ), device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-1), dtype=self.dtype, memory_config=wqkv_mem_config, layout=self.model_config["ATTN_W_LAYOUT_TILE"], @@ -152,7 +152,7 @@ def __init__( dim=-1, ).unsqueeze(0), device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-1), dtype=self.dtype, memory_config=self.model_config["ATTN_BIAS_WEIGHTS_MEMCFG"], layout=self.model_config["ATTN_B_LAYOUT_TILE"], @@ -174,7 +174,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-1), cache_file_name=cache_name("wo_width_sharded"), ) self.wo = ttnn.to_device(wo_ttnn, self.mesh_device) @@ -190,7 +190,7 @@ def __init__( -1, ), device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-2), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-2), memory_config=wo_mem_config, dtype=self.dtype, layout=self.model_config["ATTN_W_LAYOUT_TILE"], @@ -236,7 +236,7 @@ def __init__( ttnn.as_tensor( k_or_v, device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=1), 
layout=self.model_config["ATTN_W_LAYOUT_TILE"], dtype=self.dtype, cache_file_name=f"{weight_cache_path}/kvcache_{k_or_v.shape}" diff --git a/models/demos/qwen/tt/qwen_embedding.py b/models/demos/qwen/tt/qwen_embedding.py index 9cefdf8af90..ad9a0f10a67 100644 --- a/models/demos/qwen/tt/qwen_embedding.py +++ b/models/demos/qwen/tt/qwen_embedding.py @@ -28,7 +28,7 @@ def __init__( torch_weight, dtype=dtype, device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=3), layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=args.get_model_config()["EMB_WEIGHTS_MEMCFG"], cache_file_name=cache_name, diff --git a/models/demos/qwen/tt/qwen_mlp.py b/models/demos/qwen/tt/qwen_mlp.py index ad500853920..ca9976166e0 100644 --- a/models/demos/qwen/tt/qwen_mlp.py +++ b/models/demos/qwen/tt/qwen_mlp.py @@ -38,7 +38,7 @@ def __init__( torch_weight(name_dict[name[:2]]), # Grab only the wX part of the name dtype=type, device=self.mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=dim), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=dim), layout=ttnn.TILE_LAYOUT, memory_config=w2_mem_config if "w2" in name else w1_w3_mem_config, cache_file_name=cache_name(name), diff --git a/models/demos/t3000/falcon40b/demo/demo.py b/models/demos/t3000/falcon40b/demo/demo.py index 3e53c1b0a8e..d7aa80dea38 100644 --- a/models/demos/t3000/falcon40b/demo/demo.py +++ b/models/demos/t3000/falcon40b/demo/demo.py @@ -130,7 +130,7 @@ def initialize_and_fill_kv_cache( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_v_cache = ttnn.as_tensor( tensor=tt_v_cache_host, @@ -138,7 +138,7 @@ def initialize_and_fill_kv_cache( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) kv_cache += ((tt_k_cache, tt_v_cache),) diff --git a/models/demos/t3000/falcon40b/tests/test_falcon_attention.py b/models/demos/t3000/falcon40b/tests/test_falcon_attention.py index c7d142c6a9f..55bff77d105 100644 --- a/models/demos/t3000/falcon40b/tests/test_falcon_attention.py +++ b/models/demos/t3000/falcon40b/tests/test_falcon_attention.py @@ -7,7 +7,7 @@ from loguru import logger import ttnn -from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.falcon40b.reference.hf_modeling_falcon import ( FalconForCausalLM, ) @@ -124,7 +124,7 @@ def run_test_FalconAttention_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_v_cache = ttnn.as_tensor( @@ -133,7 +133,7 @@ def run_test_FalconAttention_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_layer_past = (tt_k_cache, tt_v_cache) @@ -161,7 +161,7 @@ def run_test_FalconAttention_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, 
memory_config=model_config["LN_ATTN_OUTPUT_MEMCFG"], - ttnn.replicate_tensor_to_mesh_mapper(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), preprocess=lambda x: x.unsqueeze(1).transpose(0, 2), ) @@ -185,7 +185,7 @@ def run_test_FalconAttention_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=attention_mask_memconfig, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), preprocess=lambda x: (x.transpose(0, 2) * -1e5).expand(-1, configuration.num_attention_heads, -1, -1), ) @@ -200,7 +200,7 @@ def run_test_FalconAttention_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_v_cache = ttnn.as_tensor( tensor=tt_v_cache_host, @@ -208,7 +208,7 @@ def run_test_FalconAttention_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_layer_past = (tt_k_cache, tt_v_cache) diff --git a/models/demos/t3000/falcon40b/tests/test_falcon_causallm.py b/models/demos/t3000/falcon40b/tests/test_falcon_causallm.py index a853b74a827..f79ce843b72 100644 --- a/models/demos/t3000/falcon40b/tests/test_falcon_causallm.py +++ b/models/demos/t3000/falcon40b/tests/test_falcon_causallm.py @@ -7,7 +7,7 @@ from loguru import logger import ttnn -from ttnn import ShardTensorToMesh, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.falcon40b.reference.hf_modeling_falcon import ( FalconForCausalLM, ) @@ -101,7 +101,7 @@ def run_test_FalconCausalLM_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_v_cache = ttnn.as_tensor( tensor=tt_kv_cache_host, @@ -109,7 +109,7 @@ def run_test_FalconCausalLM_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_layer_past += ((tt_k_cache, tt_v_cache),) @@ -141,7 +141,7 @@ def run_test_FalconCausalLM_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_v_cache = ttnn.as_tensor( tensor=tt_v_cache_host, @@ -149,7 +149,7 @@ def run_test_FalconCausalLM_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_layer_past += ((tt_k_cache, tt_v_cache),) diff --git a/models/demos/t3000/falcon40b/tests/test_falcon_decoder.py b/models/demos/t3000/falcon40b/tests/test_falcon_decoder.py index ef66249a132..76b832cc29b 100644 --- a/models/demos/t3000/falcon40b/tests/test_falcon_decoder.py +++ b/models/demos/t3000/falcon40b/tests/test_falcon_decoder.py @@ -7,7 +7,7 @@ from loguru import logger import ttnn -from ttnn import ShardTensorToMesh, ConcatMeshToTensor +from ttnn import 
shard_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.falcon40b.reference.hf_modeling_falcon import ( FalconForCausalLM, ) @@ -93,7 +93,7 @@ def run_test_FalconDecoder_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=-1), ) attention_mask_memconfig = model_config["ATTN_MASK_MEMCFG"] @@ -108,7 +108,7 @@ def run_test_FalconDecoder_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=attention_mask_memconfig, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), preprocess=lambda x: (x * -1e5).expand(-1, mesh_device.get_num_devices(), -1, -1), ) @@ -121,7 +121,7 @@ def run_test_FalconDecoder_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_v_cache = ttnn.as_tensor( tensor=tt_v_cache_host, @@ -129,7 +129,7 @@ def run_test_FalconDecoder_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_layer_past = (tt_k_cache, tt_v_cache) @@ -167,7 +167,7 @@ def run_test_FalconDecoder_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=-1), preprocess=lambda x: x.unsqueeze(1).transpose(0, 2), ) @@ -192,7 +192,7 @@ def run_test_FalconDecoder_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=attention_mask_memconfig, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), preprocess=lambda x: (x.transpose(0, 2) * -1e5).expand(-1, configuration.num_attention_heads, -1, -1), ) @@ -207,7 +207,7 @@ def run_test_FalconDecoder_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_v_cache = ttnn.as_tensor( tensor=tt_v_cache_host, @@ -216,7 +216,7 @@ def run_test_FalconDecoder_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_layer_past = (tt_k_cache, tt_v_cache) diff --git a/models/demos/t3000/falcon40b/tests/test_falcon_model.py b/models/demos/t3000/falcon40b/tests/test_falcon_model.py index 3696d037bbb..e6cffb88639 100644 --- a/models/demos/t3000/falcon40b/tests/test_falcon_model.py +++ b/models/demos/t3000/falcon40b/tests/test_falcon_model.py @@ -6,7 +6,7 @@ import pytest from loguru import logger import ttnn -from ttnn import ShardTensorToMesh, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.falcon40b.reference.hf_modeling_falcon import ( FalconForCausalLM, ) @@ -95,7 +95,7 @@ def run_test_FalconModel_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), )
tt_v_cache = ttnn.as_tensor( tensor=tt_kv_cache_host, @@ -103,7 +103,7 @@ def run_test_FalconModel_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_layer_past += ((tt_k_cache, tt_v_cache),) @@ -136,7 +136,7 @@ def run_test_FalconModel_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_v_cache = ttnn.as_tensor( tensor=tt_v_cache_host, @@ -144,7 +144,7 @@ def run_test_FalconModel_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_layer_past += ((tt_k_cache, tt_v_cache),) diff --git a/models/demos/t3000/falcon40b/tests/test_falcon_prefill_determinism.py b/models/demos/t3000/falcon40b/tests/test_falcon_prefill_determinism.py index 2f023c7eb04..82c070d701e 100644 --- a/models/demos/t3000/falcon40b/tests/test_falcon_prefill_determinism.py +++ b/models/demos/t3000/falcon40b/tests/test_falcon_prefill_determinism.py @@ -7,7 +7,7 @@ from loguru import logger import ttnn -from ttnn import ShardTensorToMesh, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, ConcatMeshToTensor from models.demos.t3000.falcon40b.reference.hf_modeling_falcon import FalconForCausalLM, FalconConfig from models.demos.t3000.falcon40b.tt.falcon_causallm import TtFalconCausalLM from models.demos.t3000.falcon40b.tt.model_config import get_model_config, model_config_entries @@ -68,7 +68,7 @@ def run_test_falcon_prefill_end_to_end_determinism( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_v_cache = ttnn.as_tensor( tensor=tt_v_cache_host, @@ -76,7 +76,7 @@ def run_test_falcon_prefill_end_to_end_determinism( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_layer_past += ((tt_k_cache, tt_v_cache),) diff --git a/models/demos/t3000/falcon40b/tt/falcon_attention.py b/models/demos/t3000/falcon40b/tt/falcon_attention.py index e4f0385b611..f6afd108e5e 100644 --- a/models/demos/t3000/falcon40b/tt/falcon_attention.py +++ b/models/demos/t3000/falcon40b/tt/falcon_attention.py @@ -8,7 +8,7 @@ from typing import Optional, Tuple import ttnn -from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper from models.utility_functions import nearest_32 from models.demos.t3000.falcon40b.tt.model_utils import convert_to_layout @@ -165,7 +165,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=self.model_config["FUSED_QKV_MM_WEIGHTS_MEMCFG"], - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-1), cache_file_name=query_key_value_path, preprocess=lambda x: torch.transpose(x.reshape(1, 1, *x.shape), -2, -1), ) @@ -178,7 +178,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, device=self.mesh_device,
memory_config=self.model_config["SELFOUT_MM_WEIGHTS_MEMCFG"], - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-1), cache_file_name=selfout_path, preprocess=lambda x: torch.transpose(x.reshape(1, 1, *x.shape), -2, -1), ) diff --git a/models/demos/t3000/falcon40b/tt/falcon_causallm.py b/models/demos/t3000/falcon40b/tt/falcon_causallm.py index 9f971d2e988..ffb8b1a1e60 100644 --- a/models/demos/t3000/falcon40b/tt/falcon_causallm.py +++ b/models/demos/t3000/falcon40b/tt/falcon_causallm.py @@ -6,7 +6,7 @@ from typing import Optional, Tuple import ttnn -from ttnn import ShardTensorToMesh +from ttnn import shard_tensor_to_mesh_mapper from models.demos.t3000.falcon40b.tt.falcon_model import TtFalconModelShared from models.demos.t3000.falcon40b.tt.model_utils import falcon_prefill_matmul, determine_tensor_deallocation @@ -50,7 +50,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=self.model_config["LM_HEAD_MM_WEIGHTS_MEMCFG"], - mesh_mapper=ShardTensorToMesh(mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3), cache_file_name=lm_head_path, preprocess=lambda x: torch.transpose(x.reshape(1, 1, *x.shape), -2, -1), ) diff --git a/models/demos/t3000/falcon40b/tt/falcon_embeddings.py b/models/demos/t3000/falcon40b/tt/falcon_embeddings.py index 8135a41b1a5..d0a1e88e0c7 100644 --- a/models/demos/t3000/falcon40b/tt/falcon_embeddings.py +++ b/models/demos/t3000/falcon40b/tt/falcon_embeddings.py @@ -4,7 +4,7 @@ import torch import ttnn -from ttnn import ShardTensorToMesh +from ttnn import shard_tensor_to_mesh_mapper class TtFalconEmbeddings(torch.nn.Module): @@ -25,7 +25,7 @@ def __init__(self, mesh_device, state_dict, cache_path, model_config): device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, cache_file_name=cache_path / base_name, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=-1), preprocess=lambda x: x.reshape(1, 1, *x.shape), ) diff --git a/models/demos/t3000/falcon40b/tt/falcon_mlp.py b/models/demos/t3000/falcon40b/tt/falcon_mlp.py index 5ed3b36bf9a..b1c90745bfe 100644 --- a/models/demos/t3000/falcon40b/tt/falcon_mlp.py +++ b/models/demos/t3000/falcon40b/tt/falcon_mlp.py @@ -8,7 +8,7 @@ from typing import List from models.demos.t3000.falcon40b.tt.model_utils import falcon_prefill_matmul, determine_tensor_deallocation -from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper class TtFalconMLP: @@ -43,7 +43,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=self.model_config["DENSE_H_TO_4H_MM_WEIGHTS_MEMCFG"], - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=3), cache_file_name=tt_cache_path / dense_h_to_4h_str, preprocess=lambda x: torch.transpose(x.reshape(1, 1, *x.shape), -2, -1), ) @@ -54,7 +54,7 @@ def __init__( layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=self.model_config["DENSE_4H_TO_H_MM_WEIGHTS_MEMCFG"], - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=2), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=2), cache_file_name=tt_cache_path / f"{dense_4h_to_h_str}_height_fractured", preprocess=lambda x: torch.transpose(x.reshape(1, 1, *x.shape), -2, -1), ) diff --git a/models/demos/t3000/falcon40b/tt/falcon_model.py 
b/models/demos/t3000/falcon40b/tt/falcon_model.py index c6a92cf1e4e..ac601a2650e 100644 --- a/models/demos/t3000/falcon40b/tt/falcon_model.py +++ b/models/demos/t3000/falcon40b/tt/falcon_model.py @@ -9,7 +9,7 @@ import ttnn -from ttnn import replicate_tensor_to_mesh_mapper, ShardTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper, shard_tensor_to_mesh_mapper from models.demos.t3000.falcon40b.tt.falcon_decoder import TtFalconDecoderLayer from models.demos.t3000.falcon40b.tt.falcon_embeddings import TtFalconEmbeddings from models.demos.t3000.falcon40b.tt.falcon_attention import generate_cos_sin_cache @@ -107,7 +107,7 @@ def __init__( layout=ttnn.ROW_MAJOR_LAYOUT, device=mesh_device, memory_config=self.model_config["LN_F_WEIGHTS_MEMCFG"], - ttnn.replicate_tensor_to_mesh_mapper(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), cache_file_name=layernorm_weights_path, preprocess=lambda x: x.reshape(1, 1, -1, 32), ) @@ -118,7 +118,7 @@ def __init__( layout=ttnn.ROW_MAJOR_LAYOUT, device=mesh_device, memory_config=self.model_config["LN_F_BIAS_MEMCFG"], - ttnn.replicate_tensor_to_mesh_mapper(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), cache_file_name=layernorm_bias_path, preprocess=lambda x: x.reshape(1, 1, -1, 32), ) @@ -138,7 +138,7 @@ def create_attn_mask(self, max_seq_len): layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=attention_mask_memconfig, - ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), preprocess=lambda x: (x * -1e5), ) @@ -181,7 +181,7 @@ def model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), ) # Generate input and attention_mask --------------------------------------------- @@ -230,7 +230,7 @@ def model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=self.model_config["DEFAULT_MEMCFG"], - ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), preprocess=lambda x: (x.transpose(0, 2) * -1e5).expand(1, 1, -1, -1), ) diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_decoder.py b/models/demos/t3000/llama2_70b/tests/test_llama_decoder.py index 5e94c4ddfc5..f1e738f011a 100644 --- a/models/demos/t3000/llama2_70b/tests/test_llama_decoder.py +++ b/models/demos/t3000/llama2_70b/tests/test_llama_decoder.py @@ -119,7 +119,7 @@ def tt_llama_decoder_prepare_inputs(llama_decoder_model, x, start_pos, mode, rop dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ShardTensorToMesh(llama_decoder_model.mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(llama_decoder_model.mesh_device, dim=3), device=llama_decoder_model.mesh_device, ) xs = ttnn.to_device(xs, llama_decoder_model.mesh_device) @@ -171,7 +171,7 @@ def tt_llama_decoder_prepare_inputs(llama_decoder_model, x, start_pos, mode, rop dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ttnn.ShardTensorToMesh(llama_decoder_model.mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(llama_decoder_model.mesh_device, dim=3), 
device=llama_decoder_model.mesh_device, ) xs = ttnn.to_device(xs, llama_decoder_model.mesh_device) diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_generation.py b/models/demos/t3000/llama2_70b/tests/test_llama_generation.py index 2c1b66ceaa2..d6ca6ab4bfe 100644 --- a/models/demos/t3000/llama2_70b/tests/test_llama_generation.py +++ b/models/demos/t3000/llama2_70b/tests/test_llama_generation.py @@ -6,7 +6,7 @@ import torch from torch import nn import ttnn -from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor import scipy diff --git a/models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py b/models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py index f43779eafdf..dd11900e29d 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py +++ b/models/demos/t3000/llama2_70b/tt/llama_attention_optimized.py @@ -6,7 +6,7 @@ import math import torch import ttnn -from ttnn import ShardTensorToMesh +from ttnn import shard_tensor_to_mesh_mapper from models.demos.t3000.falcon40b.tt.model_utils import matmul_2d_config_from_tensor_shapes @@ -110,7 +110,7 @@ def init_kv_cache(self): ttnn.as_tensor( lp, device=self.mesh_device, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=1), layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=self.kv_dtype, @@ -179,7 +179,7 @@ def load_weights(self): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=3), cache_file_name=self.cache_path / wqkv_cache_str, ) self.qkv = ttnn.to_device(qkv_ttnn, self.mesh_device) @@ -190,7 +190,7 @@ def load_weights(self): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=3), cache_file_name=self.cache_path / wo_str, ) diff --git a/models/demos/t3000/llama2_70b/tt/llama_decoder_optimized.py b/models/demos/t3000/llama2_70b/tt/llama_decoder_optimized.py index 46b8795ce9c..3e38d525dac 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_decoder_optimized.py +++ b/models/demos/t3000/llama2_70b/tt/llama_decoder_optimized.py @@ -6,7 +6,7 @@ from typing import List import torch import ttnn -from ttnn import replicate_tensor_to_mesh_mapper, ShardTensorToMesh +from ttnn import replicate_tensor_to_mesh_mapper, shard_tensor_to_mesh_mapper from models.demos.t3000.llama2_70b.tt.llama_attention_optimized import TtLlamaAttention_optimized from models.demos.t3000.llama2_70b.tt.llama_mlp_optimized import TtLlamaMLP_optimized @@ -117,7 +117,7 @@ def load_weights(self): layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=2), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=2), cache_file_name=self.cache_path / attn_norm_sharded_str, ) self.attn_norm_sharded = ttnn.to_device(attn_norm_sharded_ttnn, self.mesh_device) @@ -128,7 +128,7 @@ def load_weights(self): layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), + 
mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), cache_file_name=self.cache_path / ffn_norm_str, ) self.ffn_norm = ttnn.to_device(ffn_norm_ttnn, self.mesh_device) @@ -139,7 +139,7 @@ def load_weights(self): layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=2), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=2), cache_file_name=self.cache_path / ffn_norm_sharded_str, ) self.ffn_norm_sharded = ttnn.to_device(ffn_norm_sharded_ttnn, self.mesh_device) diff --git a/models/demos/t3000/llama2_70b/tt/llama_embedding.py b/models/demos/t3000/llama2_70b/tt/llama_embedding.py index 177cfa7e293..8211edb21c0 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_embedding.py +++ b/models/demos/t3000/llama2_70b/tt/llama_embedding.py @@ -4,7 +4,7 @@ import torch import ttnn -from ttnn import ShardTensorToMesh +from ttnn import shard_tensor_to_mesh_mapper class TtLlamaEmbedding: @@ -44,7 +44,7 @@ def __init__( layout=ttnn.ROW_MAJOR_LAYOUT, device=mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3), cache_file_name=cache_path / base_name, ) self.emb_weights = ttnn.to_device(embd_weights_ttn, mesh_device) diff --git a/models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py b/models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py index aa0d5ae2a24..4356e5dfdaa 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py +++ b/models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py @@ -6,7 +6,7 @@ from typing import List import torch import ttnn -from ttnn import ShardTensorToMesh +from ttnn import shard_tensor_to_mesh_mapper from models.utility_functions import nearest_32 from models.demos.t3000.falcon40b.tt.model_utils import matmul_2d_config @@ -89,7 +89,7 @@ def load_weights(self): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=w3_mem_config, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=3), cache_file_name=self.cache_path / w1_dram_shard_str, ) @@ -105,7 +105,7 @@ def load_weights(self): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=w2_memory_config, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=2), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=2), cache_file_name=self.cache_path / w2_dram_shard_str, ) @@ -115,7 +115,7 @@ def load_weights(self): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=w3_mem_config, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=3), cache_file_name=self.cache_path / w3_dram_shard_str, ) diff --git a/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py b/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py index ceac83f3417..29b59aa5490 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py +++ b/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py @@ -7,7 +7,7 @@ from tqdm import tqdm import torch import ttnn -from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper from models.utility_functions import nearest_32, profiler @@ -139,7 +139,7 @@ def load_weights(self): layout=ttnn.TILE_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - 
mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=3), cache_file_name=self.cache_path / lm_head_str, ) self.lm_head = ttnn.to_device(padded_lm_head_ttnn, self.mesh_device) @@ -150,7 +150,7 @@ def load_weights(self): layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(self.mesh_device), cache_file_name=self.cache_path / norm_str, ) self.norm = ttnn.to_device(norm_ttnn, self.mesh_device) @@ -161,7 +161,7 @@ def load_weights(self): layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=2), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=2), cache_file_name=self.cache_path / norm_sharded_str, ) self.norm_sharded = ttnn.to_device(norm_sharded_ttnn, self.mesh_device) diff --git a/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py b/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py index 641c42dd847..c768c940384 100644 --- a/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py +++ b/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py @@ -5,7 +5,7 @@ import torch import ttnn from models.utility_functions import nearest_32 -from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.common.lightweightmodule import LightweightModule @@ -74,7 +74,7 @@ def __init__(self, mesh_device, state_dict, args, layer_num, dtype): .unsqueeze(0) .unsqueeze(0), device=self.mesh_device, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-1), dtype=self.dtype, memory_config=self.model_config["ATTN_WEIGHTS_MEMCFG"], layout=self.model_config["ATTN_W_LAYOUT_TILE"], @@ -90,7 +90,7 @@ def __init__(self, mesh_device, state_dict, args, layer_num, dtype): .unsqueeze(0) .unsqueeze(0), device=self.mesh_device, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=-2), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-2), dtype=self.dtype, memory_config=self.model_config["ATTN_WEIGHTS_MEMCFG"], layout=self.model_config["ATTN_W_LAYOUT_TILE"], @@ -118,7 +118,7 @@ def __init__(self, mesh_device, state_dict, args, layer_num, dtype): ttnn.as_tensor( lp, device=self.mesh_device, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=1), dtype=ttnn.bfloat8_b, layout=self.model_config["ATTN_W_LAYOUT_TILE"], memory_config=self.model_config["ATTN_CACHE_WEIGHTS_MEMCFG"], diff --git a/models/demos/t3000/mixtral8x7b/tt/mixtral_mlp.py b/models/demos/t3000/mixtral8x7b/tt/mixtral_mlp.py index c1272b8e62a..496ebd7346c 100644 --- a/models/demos/t3000/mixtral8x7b/tt/mixtral_mlp.py +++ b/models/demos/t3000/mixtral8x7b/tt/mixtral_mlp.py @@ -4,7 +4,7 @@ import torch import ttnn -from ttnn import ShardTensorToMesh +from ttnn import shard_tensor_to_mesh_mapper from models.common.lightweightmodule import LightweightModule @@ -36,7 +36,7 @@ def __init__(self, mesh_device, state_dict, args, layer_num, dtypes): torch_weight(name), dtype=dtypes[name], device=self.mesh_device, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=0), 
layout=self.model_config["MLP_W_LAYOUT_TILE"], memory_config=self.model_config["MLP_WEIGHTS_MEMCFG"], cache_file_name=cache_name(name), diff --git a/models/demos/t3000/mixtral8x7b/tt/mixtral_moe.py b/models/demos/t3000/mixtral8x7b/tt/mixtral_moe.py index 1fc4d945045..8aaeeffee58 100644 --- a/models/demos/t3000/mixtral8x7b/tt/mixtral_moe.py +++ b/models/demos/t3000/mixtral8x7b/tt/mixtral_moe.py @@ -4,7 +4,7 @@ import torch import ttnn -from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper from models.common.lightweightmodule import LightweightModule @@ -44,7 +44,7 @@ def __init__(self, mesh_device, state_dict, experts, args, layer_num, dtype): memory_config=self.model_config["GATE_WEIGHTS_MEMCFG"], cache_file_name=cache_name, device=self.mesh_device, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) self.tile_size = 32 @@ -58,7 +58,7 @@ def __init__(self, mesh_device, state_dict, experts, args, layer_num, dtype): dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - ttnn.replicate_tensor_to_mesh_mapper(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) self.top8_mask_11B_64 = ttnn.sum(self.top8_mask_11B_64, dim=2) diff --git a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_attention.py b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_attention.py index edbb7a18da1..f55c4d3e784 100644 --- a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_attention.py +++ b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_attention.py @@ -21,7 +21,7 @@ import transformers from loguru import logger -from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor PRETRAINED_MODEL_NAME = f"tiiuae/falcon-7b-instruct" @@ -104,7 +104,7 @@ def test_falcon_attention( seq_len, configuration.hidden_size, mesh_device, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=shard_dim), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=shard_dim), ) position_ids = create_position_ids(llm_mode, kv_cache_len) attention_mask, tt_attention_mask = create_attention_mask( @@ -116,7 +116,7 @@ def test_falcon_attention( configuration.num_attention_heads, kv_cache_len, mesh_device, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=shard_dim), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=shard_dim), ) layer_past, tt_layer_past = create_kv_cache( llm_mode, @@ -125,7 +125,7 @@ def test_falcon_attention( kv_cache_len, configuration, mesh_device, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) pytorch_out, pytorch_layer_present = torch_model( diff --git a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_causallm.py b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_causallm.py index c74076bd75e..267259df1ab 100644 --- a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_causallm.py +++ b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_causallm.py @@ -22,7 +22,7 @@ ) from loguru import logger -from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor PRETRAINED_MODEL_NAME = f"tiiuae/falcon-7b-instruct" @@ -111,7 +111,7 @@ def 
test_falcon_causal_lm( kv_cache_len, configuration, mesh_device, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) tt_layer_past += (tt_current_layer_past,) attention_mask = None @@ -127,7 +127,7 @@ def test_falcon_causal_lm( kv_cache_len, configuration, mesh_device, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) past_key_values += (current_layer_past,) tt_layer_past += (tt_current_layer_past,) @@ -327,7 +327,7 @@ def test_t3k_falcon_causal_lm_with_trace( kv_cache_len, configuration, t3k_mesh_device, - mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(t3k_mesh_device, dim=0), ) tt_layer_past += (tt_current_layer_past,) attention_mask = None @@ -343,7 +343,7 @@ def test_t3k_falcon_causal_lm_with_trace( kv_cache_len, configuration, t3k_mesh_device, - mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(t3k_mesh_device, dim=0), ) past_key_values += (current_layer_past,) tt_layer_past += (tt_current_layer_past,) diff --git a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_decoder.py b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_decoder.py index ea7aad0cc7f..bb247fd88a4 100644 --- a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_decoder.py +++ b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_decoder.py @@ -21,7 +21,7 @@ ) from loguru import logger -from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor PRETRAINED_MODEL_NAME = f"tiiuae/falcon-7b-instruct" @@ -102,7 +102,7 @@ def test_falcon_decoder( seq_len, configuration.hidden_size, mesh_device, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=shard_dim), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=shard_dim), ) position_ids = create_position_ids(llm_mode, kv_cache_len) attention_mask, tt_attention_mask = create_attention_mask( @@ -114,7 +114,7 @@ def test_falcon_decoder( configuration.num_attention_heads, kv_cache_len, mesh_device, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=shard_dim), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=shard_dim), ) layer_past, tt_layer_past = create_kv_cache( llm_mode, @@ -123,7 +123,7 @@ def test_falcon_decoder( kv_cache_len, configuration, mesh_device, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) pytorch_out, pytorch_layer_present = torch_model( diff --git a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_mlp.py b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_mlp.py index 5edb9f55ef5..af8e7a6beda 100644 --- a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_mlp.py +++ b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_mlp.py @@ -11,7 +11,7 @@ from models.demos.ttnn_falcon7b.tt.common import create_custom_preprocessor, strip_state_dict_prefix from ttnn.model_preprocessing import preprocess_model_parameters from tests.ttnn.utils_for_testing import assert_with_pcc -from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor import transformers from loguru import logger @@ -88,7 +88,7 @@ def test_falcon_mlp( 
tt_cache_path=get_tt_cache_path(f"{model_name}"), device=mesh_device, base_file_name=get_model_prefix(), - weights_ttnn.replicate_tensor_to_mesh_mapper(mesh_device), + weights_mesh_mapper=replicate_tensor_to_mesh_mapper(mesh_device), ), ) @@ -98,7 +98,7 @@ def test_falcon_mlp( dtype=model_config["DEFAULT_DTYPE"], device=mesh_device, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) ttnn_output = ttnn_model(ttnn_input) diff --git a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_model.py b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_model.py index 8786d1d722e..ebdf747aa5d 100644 --- a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_model.py +++ b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_model.py @@ -23,7 +23,7 @@ ) from loguru import logger -from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor PRETRAINED_MODEL_NAME = f"tiiuae/falcon-7b-instruct" @@ -114,7 +114,7 @@ def test_falcon_model( kv_cache_len, configuration, mesh_device, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) tt_layer_past += (tt_current_layer_past,) attention_mask = None @@ -130,7 +130,7 @@ def test_falcon_model( kv_cache_len, configuration, mesh_device, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) past_key_values += (current_layer_past,) tt_layer_past += (tt_current_layer_past,) diff --git a/models/demos/ttnn_falcon7b/tt/falcon_model.py b/models/demos/ttnn_falcon7b/tt/falcon_model.py index a54d8e36242..08ad282b6fe 100644 --- a/models/demos/ttnn_falcon7b/tt/falcon_model.py +++ b/models/demos/ttnn_falcon7b/tt/falcon_model.py @@ -10,7 +10,7 @@ from models.demos.ttnn_falcon7b.tt.falcon_decoder import TtFalconDecoderLayer from models.demos.ttnn_falcon7b.tt.common import create_attention_mask -from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor class TtFalconModelShared: @@ -58,7 +58,7 @@ def model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token mesh_mapper = None else: shard_dim = 2 if llm_mode == "decode" else 0 - mesh_mapper = ShardTensorToMesh(self.device, dim=shard_dim) + mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(self.device, dim=shard_dim) # Generate input and attention_mask --------------------------------------------- if llm_mode == "prefill": diff --git a/models/demos/ttnn_resnet/tests/resnet50_test_infra.py b/models/demos/ttnn_resnet/tests/resnet50_test_infra.py index 559525e51a6..842a3916666 100644 --- a/models/demos/ttnn_resnet/tests/resnet50_test_infra.py +++ b/models/demos/ttnn_resnet/tests/resnet50_test_infra.py @@ -260,7 +260,7 @@ def __init__( def get_mesh_mappers(self, device): is_mesh_device = isinstance(device, ttnn.MeshDevice) if is_mesh_device: - inputs_mesh_mapper = ttnn.ShardTensorToMesh(device, dim=0) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(device, dim=0) weights_mesh_mapper = None # ttnn.replicate_tensor_to_mesh_mapper(device) causes unnecessary replication/takes more time on the first pass output_mesh_composer = ttnn.ConcatMeshToTensor(device, dim=0) else: diff --git a/models/demos/wormhole/bert_tiny/demo/demo.py
b/models/demos/wormhole/bert_tiny/demo/demo.py index 05f92393448..0d4834d3cd7 100644 --- a/models/demos/wormhole/bert_tiny/demo/demo.py +++ b/models/demos/wormhole/bert_tiny/demo/demo.py @@ -70,7 +70,7 @@ def run_bert_question_and_answering_inference( profiler.start(f"preprocessing_parameter") mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 batch_size = 16 if mesh_device_flag else 8 - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): @@ -189,7 +189,7 @@ def run_bert_question_and_answering_inference_squad_v2( mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 batch_size = 16 if mesh_device_flag else 8 - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): parameters = preprocess_model_parameters( diff --git a/models/demos/wormhole/bert_tiny/tests/test_performance.py b/models/demos/wormhole/bert_tiny/tests/test_performance.py index 5e9029ce84b..92aa50ebc60 100644 --- a/models/demos/wormhole/bert_tiny/tests/test_performance.py +++ b/models/demos/wormhole/bert_tiny/tests/test_performance.py @@ -52,7 +52,7 @@ def test_perf_bert_tiny( torch_position_ids = torch.zeros((batch_size, sequence_size), dtype=torch.int32) torch_attention_mask = torch.zeros(1, sequence_size) - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): parameters = preprocess_model_parameters( diff --git a/models/demos/wormhole/distilbert/demo/demo.py b/models/demos/wormhole/distilbert/demo/demo.py index 51ba895798e..71b72c9e295 100644 --- a/models/demos/wormhole/distilbert/demo/demo.py +++ b/models/demos/wormhole/distilbert/demo/demo.py @@ -50,7 +50,7 @@ def run_distilbert_question_and_answering_inference( HF_model.eval() tt_model_name = f"ttnn_{model_name}_optimized" - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) weights_mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) @@ -191,7 +191,7 @@ def run_distilbert_question_and_answering_inference_squad_v2( tt_model_name = f"ttnn_{model_name}_optimized" - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) weights_mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) diff --git a/models/demos/wormhole/distilbert/tests/test_perf_distilbert.py b/models/demos/wormhole/distilbert/tests/test_perf_distilbert.py index 77c0232a765..855562661a7 100644 --- a/models/demos/wormhole/distilbert/tests/test_perf_distilbert.py +++ b/models/demos/wormhole/distilbert/tests/test_perf_distilbert.py @@ -67,7 +67,7 @@ def test_performance_distilbert_for_qa( ) tt_model_name = f"ttnn_{model_name}_optimized" - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + inputs_mesh_mapper = 
ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) weights_mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) profiler.start(f"preprocessing_parameter") diff --git a/models/experimental/functional_unet/tests/test_unet_bottleneck.py b/models/experimental/functional_unet/tests/test_unet_bottleneck.py index ce94ef1ccaa..80983c87af5 100644 --- a/models/experimental/functional_unet/tests/test_unet_bottleneck.py +++ b/models/experimental/functional_unet/tests/test_unet_bottleneck.py @@ -52,7 +52,7 @@ def test_unet_bottleneck_multi_device( if not is_n300_with_eth_dispatch_cores(mesh_device): pytest.skip("Test is only valid for N300") - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) weights_mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) diff --git a/models/experimental/functional_unet/tests/test_unet_downblock.py b/models/experimental/functional_unet/tests/test_unet_downblock.py index 69231d68504..6e8c4017dc1 100644 --- a/models/experimental/functional_unet/tests/test_unet_downblock.py +++ b/models/experimental/functional_unet/tests/test_unet_downblock.py @@ -81,7 +81,7 @@ def test_unet_downblock_multi_device( if not is_n300_with_eth_dispatch_cores(mesh_device): pytest.skip("Test is only valid for N300") - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) weights_mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) diff --git a/models/experimental/functional_unet/tests/test_unet_multi_device.py b/models/experimental/functional_unet/tests/test_unet_multi_device.py index c21759f72a6..f886a24c8ff 100644 --- a/models/experimental/functional_unet/tests/test_unet_multi_device.py +++ b/models/experimental/functional_unet/tests/test_unet_multi_device.py @@ -28,7 +28,7 @@ def test_unet_multi_device_model(batch, groups, mesh_device, use_program_cache, if not is_n300_with_eth_dispatch_cores(mesh_device) and not is_t3k_with_eth_dispatch_cores(mesh_device): pytest.skip("Test is only valid for N300 or T3000") - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) weights_mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) diff --git a/models/experimental/functional_unet/tests/test_unet_trace.py b/models/experimental/functional_unet/tests/test_unet_trace.py index 74045f761d5..0e484be93f4 100644 --- a/models/experimental/functional_unet/tests/test_unet_trace.py +++ b/models/experimental/functional_unet/tests/test_unet_trace.py @@ -231,7 +231,7 @@ def test_unet_trace_2cq_multi_device( if not is_n300_with_eth_dispatch_cores(mesh_device) and not is_t3k_with_eth_dispatch_cores(mesh_device): pytest.skip("Test is only valid for N300 or T3000") - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) weights_mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) @@ -489,7 +489,7 @@ def test_unet_trace_2cq_same_io_multi_device( if not is_n300_with_eth_dispatch_cores(mesh_device) and not is_t3k_with_eth_dispatch_cores(mesh_device): 
pytest.skip("Test is only valid for N300 or T3000") - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) weights_mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) diff --git a/models/experimental/functional_unet/tests/test_unet_upblock.py b/models/experimental/functional_unet/tests/test_unet_upblock.py index 15566539c22..dda519db358 100644 --- a/models/experimental/functional_unet/tests/test_unet_upblock.py +++ b/models/experimental/functional_unet/tests/test_unet_upblock.py @@ -98,7 +98,7 @@ def test_unet_upblock_multi_device( if not is_n300_with_eth_dispatch_cores(mesh_device): pytest.skip("Test is only valid for N300") - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) weights_mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) diff --git a/models/experimental/grok/tt/grok_attention.py b/models/experimental/grok/tt/grok_attention.py index 38fbdca652f..7b5d8158d6f 100644 --- a/models/experimental/grok/tt/grok_attention.py +++ b/models/experimental/grok/tt/grok_attention.py @@ -5,7 +5,7 @@ import torch import ttnn from models.utility_functions import nearest_32 -from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor from models.experimental.grok.tt.grok_common import LightweightModule @@ -75,7 +75,7 @@ def __init__(self, mesh_device, state_dict, args, layer_num, dtype): .unsqueeze(0) .unsqueeze(0), device=self.mesh_device, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=-1), + ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-1), dtype=self.dtype, memory_config=self.model_config["ATTN_WEIGHTS_MEMCFG"], layout=self.model_config["ATTN_W_LAYOUT_TILE"], @@ -119,7 +119,7 @@ def __init__(self, mesh_device, state_dict, args, layer_num, dtype): ttnn.as_tensor( lp, device=self.mesh_device, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=0), + ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=0), dtype=ttnn.bfloat8_b, layout=self.model_config["ATTN_W_LAYOUT_TILE"], memory_config=self.model_config["ATTN_CACHE_WEIGHTS_MEMCFG"], diff --git a/models/experimental/grok/tt/grok_common.py b/models/experimental/grok/tt/grok_common.py index ade352ee7d1..b6e05a0a84f 100644 --- a/models/experimental/grok/tt/grok_common.py +++ b/models/experimental/grok/tt/grok_common.py @@ -5,7 +5,7 @@ from loguru import logger import torch import ttnn -from ttnn import ShardTensorToMesh, ConcatMeshToTensor, replicate_tensor_to_mesh_mapper +from ttnn import shard_tensor_to_mesh_mapper, ConcatMeshToTensor, replicate_tensor_to_mesh_mapper from models.utility_functions import nearest_32 diff --git a/models/experimental/grok/tt/grok_mlp.py b/models/experimental/grok/tt/grok_mlp.py index db3cea55d8a..9771e835fa0 100644 --- a/models/experimental/grok/tt/grok_mlp.py +++ b/models/experimental/grok/tt/grok_mlp.py @@ -4,7 +4,7 @@ import torch import ttnn -from ttnn import ShardTensorToMesh +from ttnn import shard_tensor_to_mesh_mapper from models.experimental.grok.tt.grok_common import LightweightModule @@ -36,7 +36,7 @@ def __init__(self, mesh_device, state_dict, args, layer_num, dtypes): torch_weight(name), dtype=dtypes[name], 
device=self.mesh_device, - mesh_mapper=ShardTensorToMesh(self.mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=0), layout=self.model_config["MLP_W_LAYOUT_TILE"], memory_config=self.model_config["MLP_WEIGHTS_MEMCFG"], cache_file_name=cache_name(name), diff --git a/models/experimental/grok/tt/grok_model.py b/models/experimental/grok/tt/grok_model.py index 98e7c18b0b5..8a97ebf400b 100644 --- a/models/experimental/grok/tt/grok_model.py +++ b/models/experimental/grok/tt/grok_model.py @@ -61,7 +61,7 @@ def __init__( dtype=ttnn.bfloat16, memory_config=self.model_config["OUTPUT_WEIGHTS_MEMCFG"], cache_file_name=output_cache_name, - mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=-1), ) self.compute_kernel = self.args.get_compute_kernel_output_config() diff --git a/models/experimental/grok/tt/grok_moe.py b/models/experimental/grok/tt/grok_moe.py index e40c00e08e2..ede5b375701 100644 --- a/models/experimental/grok/tt/grok_moe.py +++ b/models/experimental/grok/tt/grok_moe.py @@ -4,7 +4,7 @@ import torch import ttnn -from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper from models.experimental.grok.tt.grok_common import LightweightModule from models.experimental.grok.scripts.tlog import tlog, tlog_mesh_device @@ -55,7 +55,7 @@ def __init__(self, mesh_device, state_dict, experts, args, layer_num, dtype): dtype=ttnn.uint16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3), ) top8_mask = torch.full((1, 1, 32, 64), fill_value=torch.finfo(torch.float).min) top8_mask[:, :, :, 1:9] = 0.0 diff --git a/tech_reports/CNNs/cnn_optimizations.md b/tech_reports/CNNs/cnn_optimizations.md index 49b656ed21e..3d3dfebd36d 100644 --- a/tech_reports/CNNs/cnn_optimizations.md +++ b/tech_reports/CNNs/cnn_optimizations.md @@ -196,7 +196,7 @@ Combining these two features should For more details on tracing and multi-CQs, c Throughput can be improved if multiple chips are availible by replicating the CNN across each chip.
For our UNet model, we replicate across the outermost dimension: ```python -inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) # Shard input tensor on dimension 0 across each device +inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) # Shard input tensor on dimension 0 across each device weights_mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) # Replicate weights across all devices output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) # Map multi-device tensor back to single host tensor ``` diff --git a/tech_reports/LLMs/llms.md b/tech_reports/LLMs/llms.md index 1334618d9c4..bf49b9d5326 100644 --- a/tech_reports/LLMs/llms.md +++ b/tech_reports/LLMs/llms.md @@ -913,7 +913,7 @@ for i, split_size in enumerate(split_sizes): ttnn.as_tensor( combined_split, device=mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=-1), layout=ttnn.TILE_LAYOUT, dtype=dtype, memory_config=memory_config, @@ -1206,7 +1206,7 @@ mesh_tensor_sharded = ttnn.from_torch( torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3), ) # Convert to ttnn.Tensor, tilize and move onto mesh_device (2x4 devices) by replication diff --git a/tech_reports/Programming_Mesh_of_Devices/Programming_Mesh_of_Devices_with_TT-NN.md b/tech_reports/Programming_Mesh_of_Devices/Programming_Mesh_of_Devices_with_TT-NN.md index 760d78837ea..28de54fa44f 100644 --- a/tech_reports/Programming_Mesh_of_Devices/Programming_Mesh_of_Devices_with_TT-NN.md +++ b/tech_reports/Programming_Mesh_of_Devices/Programming_Mesh_of_Devices_with_TT-NN.md @@ -140,7 +140,7 @@ torch_tensor[..., 32:64] = 2.0 # Convert to ttnn.Tensor; MeshTensor holds buffers to two shards in host-memory mesh_tensor = ttnn.from_torch( torch_tensor, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3), layout=ttnn.TILE_LAYOUT, ) ``` @@ -306,7 +306,7 @@ mesh_tensor = ttnn.from_torch( torch_tensor, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3), ) # Execute All-Gather on the tensor; `num_links=1` specifies the number of ethernet links to use @@ -338,7 +338,7 @@ mesh_tensor = ttnn.from_torch( torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3), ) # Execute Line All-Gather on the tensor @@ -452,7 +452,7 @@ torch_output = model.forward(torch_hidden_states) mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(y=1, x=4)) # Shard input activations on batch dimension to devices in the mesh -with ttnn.distribute(ttnn.ShardTensorToMesh(mesh_device, dim=0)): +with ttnn.distribute(ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0)): hidden_states = ttnn.from_torch( torch_hidden_states, dtype=ttnn.bfloat16, @@ -548,7 +548,7 @@ with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): ) # Shard model parameters on width dimension to devices in the mesh -with ttnn.distribute(ttnn.ShardTensorToMesh(t3k_mesh_device, dim=-1)): +with ttnn.distribute(ttnn.shard_tensor_to_mesh_mapper(t3k_mesh_device, dim=-1)): parameters = ttnn.model_preprocessing.preprocess_model_parameters( 
initialize_model=lambda: model, device=t3k_mesh_device, diff --git a/tests/sweep_framework/sweeps/ccl/line_all_gather.py b/tests/sweep_framework/sweeps/ccl/line_all_gather.py index b30cd0f9f1e..0440aa17d64 100644 --- a/tests/sweep_framework/sweeps/ccl/line_all_gather.py +++ b/tests/sweep_framework/sweeps/ccl/line_all_gather.py @@ -12,7 +12,7 @@ from loguru import logger from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc from tests.ttnn.unit_tests.operations.ccl.test_all_gather import is_unsupported_case -from ttnn import ShardTensorToMesh +from ttnn import shard_tensor_to_mesh_mapper # Override the default timeout in seconds for hang detection. TIMEOUT = 30 @@ -104,7 +104,7 @@ def run( input_tensor = torch.rand(input_shape).bfloat16() ttnn_tensor = ttnn.from_torch( - input_tensor, tile=ttnn.Tile(tile), mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=dim) + input_tensor, tile=ttnn.Tile(tile), ttnn.shard_tensor_to_mesh_mapper(t3k_mesh_device, dim=dim) ) input_tensor_mesh = ttnn.to_device(ttnn_tensor, t3k_mesh_device) diff --git a/tests/ttnn/distributed/test_data_parallel_example.py b/tests/ttnn/distributed/test_data_parallel_example.py index 2af396f4d7d..cd34afa2572 100644 --- a/tests/ttnn/distributed/test_data_parallel_example.py +++ b/tests/ttnn/distributed/test_data_parallel_example.py @@ -37,7 +37,7 @@ def test_data_parallel_falcon_mlp(mesh_device): torch_output = model.forward(torch_hidden_states) # Shard input activations on batch dimension to devices in the mesh - with ttnn.distribute(mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=0)): + with ttnn.distribute(mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0)): hidden_states = ttnn.from_torch( torch_hidden_states, dtype=ttnn.bfloat16, diff --git a/tests/ttnn/distributed/test_data_parallel_example_TG.py b/tests/ttnn/distributed/test_data_parallel_example_TG.py index ae45f63b467..35e6a6f699e 100644 --- a/tests/ttnn/distributed/test_data_parallel_example_TG.py +++ b/tests/ttnn/distributed/test_data_parallel_example_TG.py @@ -39,7 +39,7 @@ def test_data_parallel_falcon_mlp(mesh_device): torch_output = model.forward(torch_hidden_states) # Shard input activations on batch dimension to devices in the mesh - with ttnn.distribute(ttnn.ShardTensorToMesh(mesh_device, dim=0)): + with ttnn.distribute(ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0)): hidden_states = ttnn.from_torch( torch_hidden_states, dtype=ttnn.bfloat16, diff --git a/tests/ttnn/distributed/test_multidevice_TG.py b/tests/ttnn/distributed/test_multidevice_TG.py index a98aaa9f540..aec169c5668 100644 --- a/tests/ttnn/distributed/test_multidevice_TG.py +++ b/tests/ttnn/distributed/test_multidevice_TG.py @@ -11,7 +11,7 @@ from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc from ttnn import ( - ShardTensorToMesh, + shard_tensor_to_mesh_mapper, ShardTensor2dMesh, ttnn.replicate_tensor_to_mesh_mapper, ConcatMeshToTensor, @@ -46,7 +46,7 @@ def test_galaxy_matmul_1d_fracture(mesh_device): dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3), ) gt = act_pt @ weights_pt diff --git a/tests/ttnn/distributed/test_tensor_parallel_example_T3000.py b/tests/ttnn/distributed/test_tensor_parallel_example_T3000.py index 72a81cd766a..1faa8724328 100644 --- a/tests/ttnn/distributed/test_tensor_parallel_example_T3000.py +++ 
b/tests/ttnn/distributed/test_tensor_parallel_example_T3000.py @@ -62,7 +62,7 @@ def test_tensor_parallel_falcon_mlp(): ) # Shard model parameters on width dimension to devices in the mesh - with ttnn.distribute(ttnn.ShardTensorToMesh(mesh_device, dim=-1)): + with ttnn.distribute(ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=-1)): parameters = ttnn.model_preprocessing.preprocess_model_parameters( initialize_model=lambda: model, device=mesh_device, diff --git a/tests/ttnn/integration_tests/bert_tiny/test_bert_tiny_wh.py b/tests/ttnn/integration_tests/bert_tiny/test_bert_tiny_wh.py index b4a023a1235..34d9d72e7fc 100644 --- a/tests/ttnn/integration_tests/bert_tiny/test_bert_tiny_wh.py +++ b/tests/ttnn/integration_tests/bert_tiny/test_bert_tiny_wh.py @@ -33,7 +33,7 @@ def test_bert_attention_inference( config = hugging_face_reference_model.config mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 batch_size = 16 if mesh_device_flag else 8 - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): @@ -90,7 +90,7 @@ def test_bert_intermediate_inference( mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 batch_size = 16 if mesh_device_flag else 8 - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): @@ -137,7 +137,7 @@ def test_bert_output_inference( mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 batch_size = 16 if mesh_device_flag else 8 - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): @@ -194,7 +194,7 @@ def test_bert_layer_inference( mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 batch_size = 16 if mesh_device_flag else 8 - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): @@ -244,7 +244,7 @@ def test_bert_for_question_answering(mesh_device, model_name, sequence_size, num mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 batch_size = 16 if mesh_device_flag else 8 - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) with ttnn.distribute(ttnn.replicate_tensor_to_mesh_mapper(mesh_device)): diff --git a/tests/ttnn/integration_tests/distilbert/test_ttnn_distilbert_wh.py b/tests/ttnn/integration_tests/distilbert/test_ttnn_distilbert_wh.py index 2cea87ab1d8..ecf85b2ec18 100644 --- a/tests/ttnn/integration_tests/distilbert/test_ttnn_distilbert_wh.py +++ b/tests/ttnn/integration_tests/distilbert/test_ttnn_distilbert_wh.py @@ -28,7 +28,7 @@ def test_distilbert_for_question_answering(mesh_device, model_name, batch_size, 
tt_model_name = f"ttnn_{model_name}_optimized" - inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) weights_mesh_mapper = ttnn.replicate_tensor_to_mesh_mapper(mesh_device) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) diff --git a/tests/ttnn/unit_tests/operations/ccl/test_all_gather.py b/tests/ttnn/unit_tests/operations/ccl/test_all_gather.py index 2a42df95821..226288832fd 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_all_gather.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_all_gather.py @@ -151,7 +151,7 @@ def run_all_gather_impl( dtype=input_dtype, layout=layout, tile=ttnn.Tile(tile), - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim), device=mesh_device, ) if trace_mode: diff --git a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_llama_perf_sweep.py b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_llama_perf_sweep.py index 4357e6996d9..0c76cb164cc 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_llama_perf_sweep.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_llama_perf_sweep.py @@ -9,7 +9,7 @@ from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc from models.utility_functions import skip_for_grayskull, get_devices_for_t3000 import itertools -from ttnn import ShardTensorToMesh +from ttnn import shard_tensor_to_mesh_mapper from tests.ttnn.unit_tests.operations.ccl.test_all_gather import run_all_gather_sharded diff --git a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_matmul.py b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_matmul.py index 74f6af3b4e5..7ba28c6d07e 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_matmul.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_matmul.py @@ -6,7 +6,7 @@ import pytest from loguru import logger import ttnn -from ttnn import ShardTensorToMesh, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, ConcatMeshToTensor from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc from models.utility_functions import skip_for_grayskull, skip_for_wormhole_b0 from tests.ttnn.unit_tests.operations.ccl.test_all_gather import is_unsupported_case @@ -73,7 +73,7 @@ def run_all_gather_matmul_on_t3000_impl( layout=layout, device=t3k_mesh_device, memory_config=mem_config_weights, - mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=dim), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(t3k_mesh_device, dim=dim), tile=ttnn.Tile(tile), ) diff --git a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_nightly.py b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_nightly.py index 503c33121a2..e30c48dd5af 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_nightly.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_nightly.py @@ -12,7 +12,7 @@ is_unsupported_case, run_all_gather_on_t3000_impl, ) -from ttnn import ShardTensorToMesh +from ttnn import shard_tensor_to_mesh_mapperesh_mapper # Enumerate the post-commit cases explicitly @@ -186,7 +186,7 @@ def run_line_all_gather_instances( input_tensor = torch.rand(input_shape).bfloat16() ttnn_tensor = ttnn.from_torch( - input_tensor, tile=ttnn.Tile(tile), mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=dim) + input_tensor, tile=ttnn.Tile(tile), ttnn.shard_tensor_to_mesh_mapper(t3k_mesh_device, dim=dim) ) input_tensor_mesh = 
ttnn.to_device(ttnn_tensor, t3k_mesh_device) diff --git a/tests/ttnn/unit_tests/operations/ccl/test_barrier_t3000_frequent.py b/tests/ttnn/unit_tests/operations/ccl/test_barrier_t3000_frequent.py index d50ab25bce3..e8dd087d82d 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_barrier_t3000_frequent.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_barrier_t3000_frequent.py @@ -9,7 +9,7 @@ from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc from models.utility_functions import skip_for_grayskull from tests.ttnn.unit_tests.operations.ccl.test_all_gather import is_unsupported_case_t3k -from ttnn.distributed.distributed import ShardTensorToMesh +from ttnn.distributed.distributed import shard_tensor_to_mesh_mapper def sharded_impl( @@ -73,7 +73,7 @@ def sharded_impl( device=device, dtype=input_dtype, layout=tensor_layout, - mesh_mapper=ShardTensorToMesh(mesh_device=device, dim=dim), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device=device, dim=dim), tile=ttnn.Tile(tile), ) @@ -142,7 +142,6 @@ def run_normal( device=device, dtype=input_dtype, layout=layout, - mesh_mapper=ShardTensorToMesh(mesh_device=device, dim=dim), tile=ttnn.Tile(tile), ) for i in range(num_iters): diff --git a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_post_commit.py b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_post_commit.py index c1170936dff..7548f88ca9e 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_post_commit.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_post_commit.py @@ -127,7 +127,7 @@ def run_reduce_scatter_test( torch_tensor, dtype=input_dtype, layout=layout, - mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim), device=mesh_device, ) # Run the op diff --git a/tests/ttnn/unit_tests/operations/test_new_conv2d.py b/tests/ttnn/unit_tests/operations/test_new_conv2d.py index a699975ca38..ae38f5d7b49 100644 --- a/tests/ttnn/unit_tests/operations/test_new_conv2d.py +++ b/tests/ttnn/unit_tests/operations/test_new_conv2d.py @@ -487,7 +487,7 @@ def test_conv_features_multi_device( shard_layout=shard_layout, output_layout=output_layout, has_bias=True, - input_mesh_mapper=ttnn.ShardTensorToMesh(mesh_device, dim=0), + input_mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), weight_mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), output_mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0), groups=groups, diff --git a/tests/ttnn/unit_tests/test_multi_device.py b/tests/ttnn/unit_tests/test_multi_device.py index 2be0ec58257..55307e8608c 100644 --- a/tests/ttnn/unit_tests/test_multi_device.py +++ b/tests/ttnn/unit_tests/test_multi_device.py @@ -11,7 +11,7 @@ from tests.ttnn.utils_for_testing import assert_with_pcc -from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor ####### @@ -110,7 +110,7 @@ def test_ttnn_to_multi_device_multiple_times(mesh_device, layout, memory_config, torch_tensor = torch.rand((1, 1, 32, 32 * mesh_device.get_num_devices()), dtype=torch.bfloat16) ttnn_tensor = ttnn.from_torch( - torch_tensor, dtype=dtype, layout=layout, mesh_mapper=ShardTensorToMesh(mesh_device, dim=3) + torch_tensor, dtype=dtype, layout=layout, ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3) ) ttnn_tensor = ttnn.to_device(ttnn_tensor, mesh_device, memory_config=memory_config) 
ttnn_tensor = ttnn.to_device(ttnn_tensor, mesh_device, memory_config=memory_config) @@ -136,7 +136,7 @@ def test_ttnn_to_and_from_multi_device_shard(mesh_device, layout, memory_config, torch_tensor = torch.rand((1, 1, 32, 32 * mesh_device.get_num_devices()), dtype=torch.bfloat16) ttnn_tensor = ttnn.from_torch( - torch_tensor, dtype=dtype, layout=layout, mesh_mapper=ShardTensorToMesh(mesh_device, dim=3) + torch_tensor, dtype=dtype, layout=layout, ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3) ) ttnn_tensor = ttnn.to_device(ttnn_tensor, mesh_device, memory_config=memory_config) ttnn_loop_back_tensor = ttnn.from_device(ttnn_tensor) @@ -161,7 +161,7 @@ def test_multi_device_check_per_device_shard(mesh_device, layout, memory_config, torch_tensor = torch.rand((1, 1, 32, 64 * mesh_device.get_num_devices()), dtype=torch.bfloat16) ttnn_tensor = ttnn.from_torch( - torch_tensor, dtype=dtype, mesh_mapper=ShardTensorToMesh(mesh_device, dim=3), layout=layout + torch_tensor, dtype=dtype, ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3), layout=layout ) ttnn_tensor = ttnn.to_device(ttnn_tensor, mesh_device, memory_config=memory_config) ttnn_loop_back_tensor = ttnn.from_device(ttnn_tensor) @@ -214,7 +214,7 @@ def test_ttnn_multi_device_all_gather(pcie_mesh_device): pytest.skip("Requires multiple devices to run") full_tensor = torch.rand((1, 1, 32, 32 * pcie_mesh_device.get_num_devices()), dtype=torch.bfloat16) - ttnn_tensor = ttnn.from_torch(full_tensor, mesh_mapper=ShardTensorToMesh(pcie_mesh_device, dim=3)) + ttnn_tensor = ttnn.from_torch(full_tensor, ttnn.shard_tensor_to_mesh_mapper(pcie_mesh_device, dim=3)) ttnn_tensor = ttnn.to_device(ttnn_tensor, pcie_mesh_device) ttnn_tensor = ttnn.all_gather(ttnn_tensor, dim=3, num_links=1) @@ -237,7 +237,7 @@ def test_multi_device_single_op_unary(mesh_device): ttnn_input_tensor = ttnn.from_torch( torch_input_tensor, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3), device=mesh_device, ) ttnn_output_tensor = ttnn.gelu(ttnn_input_tensor) @@ -261,13 +261,13 @@ def test_multi_device_single_op_binary(mesh_device): torch_input_a_tensor, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3), ) ttnn_input_b_tensor = ttnn.from_torch( torch_input_b_tensor, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3), ) ttnn_output_tensor = ttnn.add(ttnn_input_a_tensor, ttnn_input_b_tensor) @@ -289,7 +289,7 @@ def test_multi_device_multi_op(mesh_device): ttnn_input_tensor = ttnn.from_torch( torch_input_tensor, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3), device=mesh_device, ) ttnn_gelu_output = ttnn.gelu(ttnn_input_tensor) @@ -314,7 +314,7 @@ def test_multi_device_data_parallel_matmul_op(mesh_device): torch_input_a_tensor, layout=ttnn.TILE_LAYOUT, device=mesh_device, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) ttnn_input_b_tensor = ttnn.from_torch( torch_input_b_tensor, @@ -349,7 +349,7 @@ def test_multi_device_as_tensor_api(mesh_device, layout, memory_config, dtype): layout=layout, memory_config=memory_config, device=mesh_device, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), 
+ mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) with tempfile.NamedTemporaryFile() as temp_file: @@ -404,7 +404,7 @@ def test_multi_device_as_tensor_api_sharded_tensor(mesh_device, layout, memory_c device=mesh_device, memory_config=memory_config, cache_file_name=f"{temp_file.name}.weight", - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) load_tensor = ttnn.as_tensor( input_tensor, @@ -413,7 +413,7 @@ def test_multi_device_as_tensor_api_sharded_tensor(mesh_device, layout, memory_c device=mesh_device, memory_config=memory_config, cache_file_name=f"{temp_file.name}.weight", - mesh_mapper=ShardTensorToMesh(mesh_device, dim=0), + ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0), ) torch_loaded_tensor = ttnn.to_torch(load_tensor, mesh_composer=ConcatMeshToTensor(mesh_device, dim=0)) expected_pcc = 0.98 if dtype == ttnn.bfloat4_b else 0.99 @@ -436,7 +436,7 @@ def test_multi_device_permute(mesh_device, layout, memory_config, dtype): torch_golden = torch.permute(torch_tensor, (0, 1, 3, 2)) ttnn_tensor = ttnn.from_torch( - torch_tensor, dtype=dtype, layout=layout, mesh_mapper=ShardTensorToMesh(mesh_device, dim=3) + torch_tensor, dtype=dtype, layout=layout, ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3) ) ttnn_tensor = ttnn.to_device(ttnn_tensor, mesh_device, memory_config=memory_config) ttnn_permute = ttnn.permute(ttnn_tensor, (0, 1, 3, 2)) @@ -478,7 +478,7 @@ def test_ttnn_multi_device_all_gather_all_devices(t3k_mesh_device): for i in range(t3k_mesh_device.get_num_devices()): full_tensor[..., i * 32 : (i + 1) * 32] = i - ttnn_tensor = ttnn.from_torch(full_tensor, mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=3)) + ttnn_tensor = ttnn.from_torch(full_tensor, ttnn.shard_tensor_to_mesh_mapper(t3k_mesh_device, dim=3)) ttnn_tensor = ttnn.to_device(ttnn_tensor, t3k_mesh_device) ttnn_tensor = ttnn.all_gather(ttnn_tensor, dim=3, num_links=1) @@ -617,7 +617,7 @@ def test_device_shard_to_torch(mesh_device): ttnn_input_tensor = ttnn.from_torch( torch_input_tensor, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ShardTensorToMesh(mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3), device=mesh_device, ) @@ -687,7 +687,7 @@ def model(submesh): for i in range(submesh.get_num_devices()): full_tensor[..., i * 32 : (i + 1) * 32] = i - ttnn_tensor = ttnn.from_torch(full_tensor, mesh_mapper=ShardTensorToMesh(submesh, dim=3)) + ttnn_tensor = ttnn.from_torch(full_tensor, ttnn.shard_tensor_to_mesh_mapper(submesh, dim=3)) ttnn_tensor = ttnn.to_device(ttnn_tensor, submesh) ttnn_tensor = ttnn.all_gather(ttnn_tensor, dim=3, num_links=1) diff --git a/tests/ttnn/unit_tests/test_multi_device_async.py b/tests/ttnn/unit_tests/test_multi_device_async.py index 86baa78f1ef..a4822361d11 100644 --- a/tests/ttnn/unit_tests/test_multi_device_async.py +++ b/tests/ttnn/unit_tests/test_multi_device_async.py @@ -9,7 +9,7 @@ from loguru import logger from tests.ttnn.utils_for_testing import assert_with_pcc import transformers - +from ttnn import shard_tensor_to_mesh_mapper ####### # Multi-Device Tensor tests running in async mode @@ -21,7 +21,7 @@ @pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.bfloat8_b]) def test_ttnn_to_and_from_multi_device_shard(pcie_mesh_device, layout, memory_config, dtype): """Shard a tensor across devices, compose it back and verify loopback tensor is same as the original tensor""" - from ttnn import ShardTensorToMesh, ConcatMeshToTensor + from ttnn import shard_tensor_to_mesh_mapper, 
ConcatMeshToTensor if dtype == ttnn.bfloat8_b and layout == ttnn.ROW_MAJOR_LAYOUT: pytest.skip("Unsupported test permutation: bfloat8_b with ROW_MAJOR_LAYOUT") @@ -31,7 +31,7 @@ def test_ttnn_to_and_from_multi_device_shard(pcie_mesh_device, layout, memory_co for i in range(100): torch_tensor = torch.rand((1, 1, 256, 512), dtype=torch.bfloat16) ttnn_tensor = ttnn.from_torch( - torch_tensor, dtype=dtype, layout=layout, mesh_mapper=ShardTensorToMesh(pcie_mesh_device, dim=3) + torch_tensor, dtype=dtype, layout=layout, mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(pcie_mesh_device, dim=3) ) ttnn_tensor = ttnn.to_device(ttnn_tensor, pcie_mesh_device, memory_config=memory_config) ttnn_loop_back_tensor = ttnn.from_device(ttnn_tensor) @@ -48,7 +48,7 @@ def test_ttnn_to_and_from_multi_device_shard(pcie_mesh_device, layout, memory_co @pytest.mark.parametrize("dtype", [ttnn.bfloat16, ttnn.bfloat8_b]) def test_multi_device_check_per_device_shard(pcie_mesh_device, layout, memory_config, dtype): """This test checks if the tensor is correctly sharded across devices""" - from ttnn import ShardTensorToMesh, ConcatMeshToTensor + from ttnn import shard_tensor_to_mesh_mapper, ConcatMeshToTensor if dtype == ttnn.bfloat8_b and layout == ttnn.ROW_MAJOR_LAYOUT: pytest.skip("Unsupported test permutation: bfloat8_b with ROW_MAJOR_LAYOUT") @@ -63,7 +63,7 @@ def test_multi_device_check_per_device_shard(pcie_mesh_device, layout, memory_co torch_tensor = torch.rand((8, 1, 1024, 1024), dtype=torch.bfloat16) ttnn_tensor = ttnn.from_torch( - torch_tensor, dtype=dtype, layout=layout, mesh_mapper=ShardTensorToMesh(pcie_mesh_device, dim=3) + torch_tensor, dtype=dtype, layout=layout, mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(pcie_mesh_device, dim=3) ) ttnn_tensor = ttnn.to_device(ttnn_tensor, pcie_mesh_device, memory_config=memory_config) ttnn_loop_back_tensor = ttnn.from_device(ttnn_tensor) @@ -83,7 +83,7 @@ def test_multi_device_check_per_device_shard(pcie_mesh_device, layout, memory_co @pytest.mark.parametrize("layout", [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT]) @pytest.mark.parametrize("memory_config", [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG]) def test_multi_device_replicate(pcie_mesh_device, shape, layout, memory_config): - """Test ttnn.replicate_tensor_to_mesh_mapper to broadcast a tensor across multiple devices""" + """Test replicate_tensor_to_mesh_mapper to broadcast a tensor across multiple devices""" from ttnn import replicate_tensor_to_mesh_mapper pcie_mesh_device.enable_async(True) @@ -114,7 +114,7 @@ def test_multi_device_replicate(pcie_mesh_device, shape, layout, memory_config): @pytest.mark.parametrize("dtype", [ttnn.bfloat8_b]) def test_ttnn_to_multi_device_tilized_parallel(pcie_mesh_device, layout, memory_config, dtype): """Test multi chip layout conversions on worker threads""" - from ttnn import ShardTensorToMesh, ConcatMeshToTensor + from ttnn import shard_tensor_to_mesh_mapper, ConcatMeshToTensor shard_dim = 3 pcie_mesh_device.enable_async(True) @@ -122,7 +122,7 @@ def test_ttnn_to_multi_device_tilized_parallel(pcie_mesh_device, layout, memory_ torch_tensor = torch.rand((8, 1, 1024, 1024), dtype=torch.bfloat16) ttnn_tensor = ttnn.from_torch( torch_tensor, - mesh_mapper=ShardTensorToMesh(pcie_mesh_device, dim=shard_dim), + ttnn.shard_tensor_to_mesh_mapper(pcie_mesh_device, dim=shard_dim), layout=layout, memory_config=memory_config, device=pcie_mesh_device, @@ -144,7 +144,7 @@ def test_ttnn_to_multi_device_tilized_parallel(pcie_mesh_device, layout, memory_ @pytest.mark.parametrize("shape", [(1, 1, 
512, 512), (1, 3, 1024, 1024)]) def test_multi_device_unary_binary_op_chain(pcie_mesh_device, program_cache, shape): """Multidevice API test: Running tensor-parallel multi-device chain of eltwise ops""" - from ttnn import ShardTensorToMesh, ConcatMeshToTensor + from ttnn import shard_tensor_to_mesh_mapper, ConcatMeshToTensor pcie_mesh_device.enable_async(True) if program_cache: @@ -164,7 +164,7 @@ def test_multi_device_unary_binary_op_chain(pcie_mesh_device, program_cache, sha ttnn_input_tensor = ttnn.from_torch( torch_input_tensor, layout=ttnn.TILE_LAYOUT, - mesh_mapper=ShardTensorToMesh(pcie_mesh_device, dim=3), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(pcie_mesh_device, dim=3), device=pcie_mesh_device, ) ttnn_output_tensor = ttnn.add( @@ -184,7 +184,7 @@ def test_multi_device_unary_binary_op_chain(pcie_mesh_device, program_cache, sha @pytest.mark.parametrize("input_a_shape", [(4, 1, 512, 512), (16, 1, 512, 512)]) def test_multi_device_data_parallel_op_chain(pcie_mesh_device, program_cache, input_a_shape): """Multidevice API: Running data-parallel chain of ops with matmul""" - from ttnn import ShardTensorToMesh, ConcatMeshToTensor, replicate_tensor_to_mesh_mapper + from ttnn import shard_tensor_to_mesh_mapper, ConcatMeshToTensor, replicate_tensor_to_mesh_mapper pcie_mesh_device.enable_async(True) if program_cache: @@ -206,13 +206,13 @@ def test_multi_device_data_parallel_op_chain(pcie_mesh_device, program_cache, in torch_input_a_tensor, layout=ttnn.TILE_LAYOUT, device=pcie_mesh_device, - mesh_mapper=ShardTensorToMesh(pcie_mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(pcie_mesh_device, dim=0), ) ttnn_input_b_tensor = ttnn.from_torch( torch_input_b_tensor, layout=ttnn.TILE_LAYOUT, device=pcie_mesh_device, - ttnn.replicate_tensor_to_mesh_mapper(pcie_mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(pcie_mesh_device), ) ttnn_output_tensor = ttnn.from_device( ttnn.mish( @@ -264,7 +264,7 @@ def test_multi_device_argmax(pcie_mesh_device, layout, mem_config): @pytest.mark.parametrize("pcie_mesh_device", [2], indirect=True) def test_multi_device_explicit_dealloc(pcie_mesh_device): """Multidevice API: Ensure that deallocating multi-device tensors works as expected""" - from ttnn import ShardTensorToMesh, ConcatMeshToTensor, replicate_tensor_to_mesh_mapper + from ttnn import shard_tensor_to_mesh_mapper, ConcatMeshToTensor, replicate_tensor_to_mesh_mapper if pcie_mesh_device.get_num_devices() <= 1: pytest.skip("Requires multiple devices to run") @@ -278,13 +278,13 @@ def test_multi_device_explicit_dealloc(pcie_mesh_device): torch_input_a_tensor, layout=ttnn.TILE_LAYOUT, device=pcie_mesh_device, - mesh_mapper=ShardTensorToMesh(pcie_mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(pcie_mesh_device, dim=0), ) ttnn_input_b_tensor = ttnn.from_torch( torch_input_b_tensor, layout=ttnn.TILE_LAYOUT, device=pcie_mesh_device, - ttnn.replicate_tensor_to_mesh_mapper(pcie_mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(pcie_mesh_device), ) ttnn_output_tensor_1 = ttnn_input_a_tensor @ ttnn_input_b_tensor ttnn_output_tensor_2 = ttnn.gelu(ttnn_output_tensor_1) diff --git a/tests/ttnn/unit_tests/test_multi_device_events.py b/tests/ttnn/unit_tests/test_multi_device_events.py index 2a162932532..824282958e6 100644 --- a/tests/ttnn/unit_tests/test_multi_device_events.py +++ b/tests/ttnn/unit_tests/test_multi_device_events.py @@ -10,7 +10,7 @@ from loguru import logger import os from tests.ttnn.utils_for_testing import assert_with_pcc -from ttnn 
import ShardTensorToMesh, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor @pytest.mark.parametrize("shape", [(1, 1, 512, 512)]) @@ -52,10 +52,10 @@ def run_op_chain(input_0, input_1, workload_cq): ) # Convert torch tensors to TTNN Multi-Device Host Tensors ttnn_input_tensor_0 = ttnn.from_torch( - torch_input_tensor_0, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=0) + torch_input_tensor_0, layout=ttnn.TILE_LAYOUT, ttnn.shard_tensor_to_mesh_mapper(t3k_mesh_device, dim=0) ) ttnn_input_tensor_1 = ttnn.from_torch( - torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=0) + torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, ttnn.shard_tensor_to_mesh_mapper(t3k_mesh_device, dim=0) ) # Copy TTNN host tensors into preallocated Mult-Device tensors, using data-movement CQ diff --git a/tests/ttnn/unit_tests/test_multi_device_trace.py b/tests/ttnn/unit_tests/test_multi_device_trace.py index 5ea92bf15d0..9de288ff0ce 100644 --- a/tests/ttnn/unit_tests/test_multi_device_trace.py +++ b/tests/ttnn/unit_tests/test_multi_device_trace.py @@ -10,7 +10,7 @@ from loguru import logger import os from tests.ttnn.utils_for_testing import assert_with_pcc -from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor NUM_TRACE_LOOPS = int(os.getenv("NUM_TRACE_LOOPS", 15)) @@ -84,10 +84,10 @@ def event_sync(device, record_cq, wait_cq): ) # Convert torch tensors to TTNN Multi-Device Host Tensors ttnn_input_tensor_0 = ttnn.from_torch( - torch_input_tensor_0, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=0) + torch_input_tensor_0, layout=ttnn.TILE_LAYOUT, ttnn.shard_tensor_to_mesh_mapper(t3k_mesh_device, dim=0) ) ttnn_input_tensor_1 = ttnn.from_torch( - torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=0) + torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, ttnn.shard_tensor_to_mesh_mapper(t3k_mesh_device, dim=0) ) # Copy TTNN host tensors into preallocated Mult-Device tensors @@ -240,10 +240,10 @@ def event_sync(device, record_cq, wait_cq): # Convert torch tensors to TTNN Multi-Device Host Tensors ttnn_input_tensor_0 = ttnn.from_torch( - torch_input_tensor_0, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=0) + torch_input_tensor_0, layout=ttnn.TILE_LAYOUT, ttnn.shard_tensor_to_mesh_mapper(t3k_mesh_device, dim=0) ) ttnn_input_tensor_1 = ttnn.from_torch( - torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(t3k_mesh_device, dim=0) + torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, ttnn.shard_tensor_to_mesh_mapper(t3k_mesh_device, dim=0) ) ttnn_weight = ttnn.from_torch( torch_weight, layout=ttnn.TILE_LAYOUT, ttnn.replicate_tensor_to_mesh_mapper(t3k_mesh_device) diff --git a/tests/ttnn/unit_tests/test_multi_device_trace_TG.py b/tests/ttnn/unit_tests/test_multi_device_trace_TG.py index bf0c10fae83..dab95c7d7ac 100644 --- a/tests/ttnn/unit_tests/test_multi_device_trace_TG.py +++ b/tests/ttnn/unit_tests/test_multi_device_trace_TG.py @@ -10,7 +10,7 @@ from loguru import logger import os from tests.ttnn.utils_for_testing import assert_with_pcc -from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper, 
ConcatMeshToTensor NUM_TRACE_LOOPS = int(os.getenv("NUM_TRACE_LOOPS", 15)) @@ -80,10 +80,10 @@ def event_sync(device, record_cq, wait_cq): ) # Convert torch tensors to TTNN Multi-Device Host Tensors ttnn_input_tensor_0 = ttnn.from_torch( - torch_input_tensor_0, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(mesh_device, dim=0) + torch_input_tensor_0, layout=ttnn.TILE_LAYOUT, ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) ) ttnn_input_tensor_1 = ttnn.from_torch( - torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(mesh_device, dim=0) + torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) ) # Copy TTNN host tensors into preallocated Mult-Device tensors @@ -218,10 +218,10 @@ def event_sync(device, record_cq, wait_cq): # Convert torch tensors to TTNN Multi-Device Host Tensors ttnn_input_tensor_0 = ttnn.from_torch( - torch_input_tensor_0, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(mesh_device, dim=0) + torch_input_tensor_0, layout=ttnn.TILE_LAYOUT, ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) ) ttnn_input_tensor_1 = ttnn.from_torch( - torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(mesh_device, dim=0) + torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) ) ttnn_weight = ttnn.from_torch( torch_weight, layout=ttnn.TILE_LAYOUT, ttnn.replicate_tensor_to_mesh_mapper(mesh_device) diff --git a/tests/ttnn/unit_tests/test_multi_device_trace_tgg.py b/tests/ttnn/unit_tests/test_multi_device_trace_tgg.py index 0d83a2cbd08..e7bcd4be2a4 100644 --- a/tests/ttnn/unit_tests/test_multi_device_trace_tgg.py +++ b/tests/ttnn/unit_tests/test_multi_device_trace_tgg.py @@ -10,7 +10,7 @@ from loguru import logger import os from tests.ttnn.utils_for_testing import assert_with_pcc -from ttnn import ShardTensorToMesh, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor +from ttnn import shard_tensor_to_mesh_mapper, replicate_tensor_to_mesh_mapper, ConcatMeshToTensor NUM_TRACE_LOOPS = int(os.getenv("NUM_TRACE_LOOPS", 15)) @@ -80,10 +80,10 @@ def event_sync(device, record_cq, wait_cq): ) # Convert torch tensors to TTNN Multi-Device Host Tensors ttnn_input_tensor_0 = ttnn.from_torch( - torch_input_tensor_0, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(mesh_device, dim=0) + torch_input_tensor_0, layout=ttnn.TILE_LAYOUT, ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) ) ttnn_input_tensor_1 = ttnn.from_torch( - torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(mesh_device, dim=0) + torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) ) # Copy TTNN host tensors into preallocated Mult-Device tensors @@ -217,10 +217,10 @@ def event_sync(device, record_cq, wait_cq): # Convert torch tensors to TTNN Multi-Device Host Tensors ttnn_input_tensor_0 = ttnn.from_torch( - torch_input_tensor_0, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(mesh_device, dim=0) + torch_input_tensor_0, layout=ttnn.TILE_LAYOUT, ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) ) ttnn_input_tensor_1 = ttnn.from_torch( - torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, mesh_mapper=ShardTensorToMesh(mesh_device, dim=0) + torch_input_tensor_1, layout=ttnn.TILE_LAYOUT, ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=0) ) ttnn_weight = ttnn.from_torch( torch_weight, layout=ttnn.TILE_LAYOUT, ttnn.replicate_tensor_to_mesh_mapper(mesh_device) diff --git 
a/tests/ttnn/unit_tests/test_sub_device.py b/tests/ttnn/unit_tests/test_sub_device.py index 763a003fc7f..821faad79ab 100644 --- a/tests/ttnn/unit_tests/test_sub_device.py +++ b/tests/ttnn/unit_tests/test_sub_device.py @@ -53,7 +53,7 @@ def run_sub_devices(device, create_fabric_sub_device=False): def run_sub_devices_program(device, create_fabric_sub_device=False): is_mesh_device = isinstance(device, ttnn.MeshDevice) if is_mesh_device: - inputs_mesh_mapper = ttnn.ShardTensorToMesh(device, dim=0) + inputs_mesh_mapper = ttnn.shard_tensor_to_mesh_mapper(device, dim=0) output_mesh_composer = ttnn.ConcatMeshToTensor(device, dim=0) num_devices = device.get_num_devices() else: diff --git a/ttnn/ttnn/distributed/distributed.py b/ttnn/ttnn/distributed/distributed.py index 8fd6ef4d848..d639eae06f2 100644 --- a/ttnn/ttnn/distributed/distributed.py +++ b/ttnn/ttnn/distributed/distributed.py @@ -306,12 +306,12 @@ def distribute(default: Union[ttnn.TensorToMesh, ttnn.MeshToTensor, MeshToTensor used to map tensors to a mesh or compose tensors from a mesh. Example: - with distribute(ShardTensorToMesh(mesh_device, dim=3)): + with distribute(shard_tensor_to_mesh_mapper(mesh_device, dim=3)): # Code here will use the default mapper result = ttnn.from_torch(torch_tensor) is equivalent to: - result = ttnn.from_torch(torch_tensor, mesh_mapper=ShardTensorToMesh(mesh_device, dim=3)) + result = ttnn.from_torch(torch_tensor, ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3)) """ _original_to_torch = ttnn.to_torch _original_from_torch = ttnn.from_torch From 29a10de678770fc5de3182c004fa40de148f288b Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Thu, 6 Mar 2025 17:08:22 +0000 Subject: [PATCH 71/76] unsaved sharding switch --- models/demos/falcon7b_common/tt/falcon_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models/demos/falcon7b_common/tt/falcon_model.py b/models/demos/falcon7b_common/tt/falcon_model.py index 83b2a51c278..d79cdee51d8 100644 --- a/models/demos/falcon7b_common/tt/falcon_model.py +++ b/models/demos/falcon7b_common/tt/falcon_model.py @@ -177,7 +177,7 @@ def model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=self.model_config["INPUT_MEMCFG"], - ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=0), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=0), ) elif llm_mode == "decode": assert batch_size % 32 == 0, "For decode, batch_size must be multiple of 32!" @@ -226,7 +226,7 @@ def model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=self.model_config["INPUT_MEMCFG"], - ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(self.mesh_device, dim=1), ) else: raise NotImplementedError(f"Llm mode {llm_mode} is not supported! 
Must be one of prefill or decode.") From 3ca0bbddb0ee7d3be63b6acc8d6982bc5b575341 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Thu, 6 Mar 2025 17:30:18 +0000 Subject: [PATCH 72/76] fix replacement errors --- .../demos/llama3/tests/multimodal/test_llama_class_embedding.py | 2 +- models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py | 2 +- .../llama3/tests/multimodal/test_llama_positional_embedding.py | 2 +- .../tests/multimodal/test_llama_tile_position_embedding.py | 2 +- models/demos/llama3/tt/multimodal/llama_conv2d_patch.py | 2 +- models/demos/llama3/tt/multimodal/llama_positional_embedding.py | 2 +- .../demos/llama3/tt/multimodal/llama_tile_position_embedding.py | 2 +- models/demos/t3000/falcon40b/tests/test_falcon_decoder.py | 2 +- models/demos/t3000/falcon40b/tests/test_falcon_model.py | 2 +- tests/ttnn/distributed/test_multidevice_TG.py | 2 +- 10 files changed, 10 insertions(+), 10 deletions(-) diff --git a/models/demos/llama3/tests/multimodal/test_llama_class_embedding.py b/models/demos/llama3/tests/multimodal/test_llama_class_embedding.py index 7f97f631ef6..451e70be45c 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_class_embedding.py +++ b/models/demos/llama3/tests/multimodal/test_llama_class_embedding.py @@ -15,7 +15,7 @@ ##### TTNN imports ##### import ttnn from ttnn import experimental as ttl -from ttnn import ConcatMeshToTensor, ttnn.replicate_tensor_to_mesh_mapper +from ttnn import ConcatMeshToTensor, replicate_tensor_to_mesh_mapper from models.utility_functions import skip_for_grayskull from models.utility_functions import ( comp_pcc, diff --git a/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py b/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py index 520c2b30cff..eea8858a6fb 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py +++ b/models/demos/llama3/tests/multimodal/test_llama_conv2d_patch.py @@ -14,7 +14,7 @@ ##### TTNN imports ##### import ttnn from ttnn import experimental as ttl -from ttnn import ConcatMeshToTensor, ttnn.replicate_tensor_to_mesh_mapper +from ttnn import ConcatMeshToTensor, replicate_tensor_to_mesh_mapper from models.utility_functions import skip_for_grayskull from models.utility_functions import ( comp_pcc, diff --git a/models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py b/models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py index cad1136dd1c..aaa2c76d20e 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py +++ b/models/demos/llama3/tests/multimodal/test_llama_positional_embedding.py @@ -17,7 +17,7 @@ ##### TTNN imports ##### import ttnn from ttnn import experimental as ttl -from ttnn import ConcatMeshToTensor, ttnn.replicate_tensor_to_mesh_mapper +from ttnn import ConcatMeshToTensor, replicate_tensor_to_mesh_mapper from models.utility_functions import skip_for_grayskull from models.utility_functions import ( comp_pcc, diff --git a/models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py b/models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py index 97517e1178b..c00c27773bf 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py +++ b/models/demos/llama3/tests/multimodal/test_llama_tile_position_embedding.py @@ -17,7 +17,7 @@ ##### TTNN imports ##### import ttnn from ttnn import experimental as ttl -from ttnn import ConcatMeshToTensor, ttnn.replicate_tensor_to_mesh_mapper +from ttnn import ConcatMeshToTensor, 
replicate_tensor_to_mesh_mapper from models.utility_functions import skip_for_grayskull from models.utility_functions import ( comp_pcc, diff --git a/models/demos/llama3/tt/multimodal/llama_conv2d_patch.py b/models/demos/llama3/tt/multimodal/llama_conv2d_patch.py index 8017b9174ce..ea1cf85685e 100644 --- a/models/demos/llama3/tt/multimodal/llama_conv2d_patch.py +++ b/models/demos/llama3/tt/multimodal/llama_conv2d_patch.py @@ -11,7 +11,7 @@ ) from models.common.lightweightmodule import LightweightModule -from ttnn import shard_tensor_to_mesh_mapper, ConcatMeshToTensor, ttnn.replicate_tensor_to_mesh_mapper +from ttnn import shard_tensor_to_mesh_mapper, ConcatMeshToTensor, replicate_tensor_to_mesh_mapper class TtLlamaConv2dPatch(LightweightModule): diff --git a/models/demos/llama3/tt/multimodal/llama_positional_embedding.py b/models/demos/llama3/tt/multimodal/llama_positional_embedding.py index 582904585ed..fea3542b12a 100644 --- a/models/demos/llama3/tt/multimodal/llama_positional_embedding.py +++ b/models/demos/llama3/tt/multimodal/llama_positional_embedding.py @@ -13,7 +13,7 @@ ) from models.common.lightweightmodule import LightweightModule -from ttnn import shard_tensor_to_mesh_mapper, ConcatMeshToTensor, ttnn.replicate_tensor_to_mesh_mapper +from ttnn import shard_tensor_to_mesh_mapper, ConcatMeshToTensor, replicate_tensor_to_mesh_mapper TILE_SIZE = 32 diff --git a/models/demos/llama3/tt/multimodal/llama_tile_position_embedding.py b/models/demos/llama3/tt/multimodal/llama_tile_position_embedding.py index 1b3171164d5..a97f20d264e 100644 --- a/models/demos/llama3/tt/multimodal/llama_tile_position_embedding.py +++ b/models/demos/llama3/tt/multimodal/llama_tile_position_embedding.py @@ -13,7 +13,7 @@ ) from models.common.lightweightmodule import LightweightModule -from ttnn import shard_tensor_to_mesh_mapper, ConcatMeshToTensor, ttnn.replicate_tensor_to_mesh_mapper +from ttnn import shard_tensor_to_mesh_mapper, ConcatMeshToTensor, replicate_tensor_to_mesh_mapper class TtLlamaTilePositionEmbedding(LightweightModule): diff --git a/models/demos/t3000/falcon40b/tests/test_falcon_decoder.py b/models/demos/t3000/falcon40b/tests/test_falcon_decoder.py index 76b832cc29b..9b9ccc81f00 100644 --- a/models/demos/t3000/falcon40b/tests/test_falcon_decoder.py +++ b/models/demos/t3000/falcon40b/tests/test_falcon_decoder.py @@ -167,7 +167,7 @@ def run_test_FalconDecoder_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"], - mesh_mapper=(mesh_device, dim=-1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=-1), preprocess=lambda x: x.unsqueeze(1).transpose(0, 2), ) diff --git a/models/demos/t3000/falcon40b/tests/test_falcon_model.py b/models/demos/t3000/falcon40b/tests/test_falcon_model.py index e6cffb88639..50048cf0fdc 100644 --- a/models/demos/t3000/falcon40b/tests/test_falcon_model.py +++ b/models/demos/t3000/falcon40b/tests/test_falcon_model.py @@ -136,7 +136,7 @@ def run_test_FalconModel_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_v_cache = ttnn.as_tensor( tensor=tt_v_cache_host, diff --git a/tests/ttnn/distributed/test_multidevice_TG.py b/tests/ttnn/distributed/test_multidevice_TG.py index aec169c5668..f866e145d63 100644 --- a/tests/ttnn/distributed/test_multidevice_TG.py +++ b/tests/ttnn/distributed/test_multidevice_TG.py @@ -39,7 +39,7 @@ def 
test_galaxy_matmul_1d_fracture(mesh_device): dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=mesh_device, - ttnn.replicate_tensor_to_mesh_mapper(mesh_device), + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device), ) weights = ttnn.from_torch( weights_pt, From 8f5246723c737266449d9d7a4e0678aa56c876ff Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Thu, 6 Mar 2025 17:33:59 +0000 Subject: [PATCH 73/76] fix more replace errors --- models/demos/t3000/falcon40b/tests/test_falcon_decoder.py | 2 +- tests/ttnn/distributed/test_multidevice_TG.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/models/demos/t3000/falcon40b/tests/test_falcon_decoder.py b/models/demos/t3000/falcon40b/tests/test_falcon_decoder.py index 9b9ccc81f00..43255352663 100644 --- a/models/demos/t3000/falcon40b/tests/test_falcon_decoder.py +++ b/models/demos/t3000/falcon40b/tests/test_falcon_decoder.py @@ -192,7 +192,7 @@ def run_test_FalconDecoder_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=attention_mask_memconfig, - mesh_mapper=(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), preprocess=lambda x: (x.transpose(0, 2) * -1e5).expand(-1, configuration.num_attention_heads, -1, -1), ) diff --git a/tests/ttnn/distributed/test_multidevice_TG.py b/tests/ttnn/distributed/test_multidevice_TG.py index f866e145d63..4e31e2a10c6 100644 --- a/tests/ttnn/distributed/test_multidevice_TG.py +++ b/tests/ttnn/distributed/test_multidevice_TG.py @@ -13,7 +13,7 @@ from ttnn import ( shard_tensor_to_mesh_mapper, ShardTensor2dMesh, - ttnn.replicate_tensor_to_mesh_mapper, + replicate_tensor_to_mesh_mapper, ConcatMeshToTensor, ConcatMesh2dToTensor, MeshToTensor, From cf9400f9e8f8efbe22f07a5522a253e4c638c165 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Thu, 6 Mar 2025 17:36:16 +0000 Subject: [PATCH 74/76] fix replace errors x3 --- models/demos/t3000/falcon40b/tests/test_falcon_decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/demos/t3000/falcon40b/tests/test_falcon_decoder.py b/models/demos/t3000/falcon40b/tests/test_falcon_decoder.py index 43255352663..18373c8a191 100644 --- a/models/demos/t3000/falcon40b/tests/test_falcon_decoder.py +++ b/models/demos/t3000/falcon40b/tests/test_falcon_decoder.py @@ -216,7 +216,7 @@ def run_test_FalconDecoder_inference( layout=ttnn.TILE_LAYOUT, device=mesh_device, memory_config=model_config["KV_CACHE_MEMCFG"], - mesh_mapper=(mesh_device, dim=1), + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=1), ) tt_layer_past = (tt_k_cache, tt_v_cache) From 1a408ddca79131089ceb0d026916b453d76e153c Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Thu, 6 Mar 2025 17:43:54 +0000 Subject: [PATCH 75/76] manual pre-commit --- .../demos/falcon7b_common/tt/model_utils.py | 4 +- models/demos/llama3/tt/llama_common.py | 4 +- .../demos/t3000/llama2_70b/tt/llama_common.py | 1 + .../tests/test_llama_attention_galaxy.py | 913 +++++++++--------- .../tg/llama3_70b/tt/llama_decoder_galaxy.py | 6 +- .../tg/llama3_70b/tt/llama_mlp_galaxy.py | 1 + .../tg/llama3_70b/tt/llama_model_galaxy.py | 1 + .../unit_tests/test_multi_device_async.py | 10 +- ttnn/ttnn/distributed/distributed.py | 5 +- 9 files changed, 453 insertions(+), 492 deletions(-) diff --git a/models/demos/falcon7b_common/tt/model_utils.py b/models/demos/falcon7b_common/tt/model_utils.py index 3bf7dc0919d..2b068eaeade 100644 --- a/models/demos/falcon7b_common/tt/model_utils.py +++ b/models/demos/falcon7b_common/tt/model_utils.py @@ 
-50,7 +50,9 @@ def preprocess_weights(weights_to_cache): layout=tt_layout, device=mesh_device, memory_config=model_config[f"{weight_config_str}_MEMCFG"], - mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device) if type(mesh_device) == ttnn.MeshDevice else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device) + if type(mesh_device) == ttnn.MeshDevice + else None, cache_file_name=str(path), preprocess=preprocess_weights, ) diff --git a/models/demos/llama3/tt/llama_common.py b/models/demos/llama3/tt/llama_common.py index 7ec888fa9b3..829d02761a9 100644 --- a/models/demos/llama3/tt/llama_common.py +++ b/models/demos/llama3/tt/llama_common.py @@ -402,7 +402,9 @@ def sample_host(tt_input, mesh_device, temperature=0.6, top_p=0.08, on_host=True layout=ttnn.ROW_MAJOR_LAYOUT, dtype=ttnn.uint32, device=None, - mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device) if mesh_device.get_num_devices() > 1 else None, + mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device) + if mesh_device.get_num_devices() > 1 + else None, ), pt_out, ) diff --git a/models/demos/t3000/llama2_70b/tt/llama_common.py b/models/demos/t3000/llama2_70b/tt/llama_common.py index a834b18e653..14474e541ba 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_common.py +++ b/models/demos/t3000/llama2_70b/tt/llama_common.py @@ -31,6 +31,7 @@ MeshToTensor, ) + class ConcatMesh2DToTensor(MeshToTensor): def __init__(self, mesh_device, dims, cluster_shape): self.dims = dims diff --git a/models/demos/tg/llama3_70b/tests/test_llama_attention_galaxy.py b/models/demos/tg/llama3_70b/tests/test_llama_attention_galaxy.py index eba08c3cb59..e455197f1cb 100644 --- a/models/demos/tg/llama3_70b/tests/test_llama_attention_galaxy.py +++ b/models/demos/tg/llama3_70b/tests/test_llama_attention_galaxy.py @@ -17,513 +17,460 @@ from models.demos.tg.llama3_70b.tt.llama_common import setup_llama_env from models.demos.t3000.llama2_70b.reference.llama.llama.model import precompute_freqs_cis from models.demos.t3000.llama2_70b.tt.llama_common import ( - check_mesh_device, - extract_pcc_from_log, - generate_rot_emb, - get_rotation_mat, - gather_cos_sin, - precompute_freqs, - MAX_SEQ_LEN, - MAX_SEQ_LEN_LLAMA3, - BASE_URL, - UNIT_TEST_N_LAYER, - UNIT_TEST_LAYER_NUM, - UNIT_TEST_START_POS, - UNIT_TEST_GENERATION_LENGTH, - comp_pcc, - get_rot_transformation_mat, - should_skip_model_load, - check_kv_cache, - num_to_corerange, - ConcatMesh2DToTensor, + check_mesh_device, + extract_pcc_from_log, + generate_rot_emb, + get_rotation_mat, + gather_cos_sin, + precompute_freqs, + MAX_SEQ_LEN, + MAX_SEQ_LEN_LLAMA3, + BASE_URL, + UNIT_TEST_N_LAYER, + UNIT_TEST_LAYER_NUM, + UNIT_TEST_START_POS, + UNIT_TEST_GENERATION_LENGTH, + comp_pcc, + get_rot_transformation_mat, + should_skip_model_load, + check_kv_cache, + num_to_corerange, + ConcatMesh2DToTensor, ) from models.utility_functions import skip_for_grayskull - - class PytorchLlamaAttentionModel(torch.nn.Module): - def __init__(self, hf_reference_model, layer_num, rope_theta): - super().__init__() - self.attention = hf_reference_model.layers[layer_num].attention - self.rope_theta = rope_theta - # Disable dropout - self.attention.eval() - - - configuration = hf_reference_model.params - self.n_heads = configuration.n_heads - hidden_dim = configuration.dim - self.head_dim = hidden_dim // self.n_heads - self.max_seq_len = configuration.max_seq_len - - - def prepare_inputs(self, x, start_pos): - """ - Prepare inputs for decode mode. 
Assume that current token is at - start_pos, and KV cache has valid data up to start_pos. - """ - batch = x.size(0) - freqs_cis = precompute_freqs_cis(self.head_dim, self.max_seq_len * 2, self.rope_theta) - freqs_cis = freqs_cis[start_pos : start_pos + 1] - - - attn_mask = torch.zeros(batch, 1, 1, start_pos + 1) - # attn_mask[:, :, :, : start_pos + 1] = -1e9 - attn_mask = attn_mask.expand(-1, self.n_heads, -1, -1) - - - return x, start_pos, freqs_cis, attn_mask - - - def prepare_inputs_prefill(self, x, start_pos): - """ - Prepare inputs for decode mode. Assume that current token is at - start_pos, and KV cache has valid data up to start_pos. - """ - batch = x.size(0) - seq_len = x.size(1) - freqs_cis = precompute_freqs_cis(self.head_dim, self.max_seq_len * 2, self.rope_theta) - freqs_cis = freqs_cis[start_pos : start_pos + seq_len] - - - attn_mask = torch.full((seq_len, seq_len), float("-inf")) - attn_mask = torch.triu(attn_mask, diagonal=1) - attn_mask = attn_mask.expand(batch, self.n_heads, -1, -1) - - - return x, start_pos, freqs_cis, attn_mask - - - def forward(self, x, start_pos, freqs_cis, mask): - """ - x: (batch, seq, hidden_dim) - start_pos: int - freqs_cis: ? - mask: ? - - - return: (batch, seq, hidden_dim) - """ - result = self.attention( - x, - start_pos, - freqs_cis, - mask, - ) - return result - - + def __init__(self, hf_reference_model, layer_num, rope_theta): + super().__init__() + self.attention = hf_reference_model.layers[layer_num].attention + self.rope_theta = rope_theta + # Disable dropout + self.attention.eval() + + configuration = hf_reference_model.params + self.n_heads = configuration.n_heads + hidden_dim = configuration.dim + self.head_dim = hidden_dim // self.n_heads + self.max_seq_len = configuration.max_seq_len + + def prepare_inputs(self, x, start_pos): + """ + Prepare inputs for decode mode. Assume that current token is at + start_pos, and KV cache has valid data up to start_pos. + """ + batch = x.size(0) + freqs_cis = precompute_freqs_cis(self.head_dim, self.max_seq_len * 2, self.rope_theta) + freqs_cis = freqs_cis[start_pos : start_pos + 1] + + attn_mask = torch.zeros(batch, 1, 1, start_pos + 1) + # attn_mask[:, :, :, : start_pos + 1] = -1e9 + attn_mask = attn_mask.expand(-1, self.n_heads, -1, -1) + + return x, start_pos, freqs_cis, attn_mask + + def prepare_inputs_prefill(self, x, start_pos): + """ + Prepare inputs for decode mode. Assume that current token is at + start_pos, and KV cache has valid data up to start_pos. + """ + batch = x.size(0) + seq_len = x.size(1) + freqs_cis = precompute_freqs_cis(self.head_dim, self.max_seq_len * 2, self.rope_theta) + freqs_cis = freqs_cis[start_pos : start_pos + seq_len] + + attn_mask = torch.full((seq_len, seq_len), float("-inf")) + attn_mask = torch.triu(attn_mask, diagonal=1) + attn_mask = attn_mask.expand(batch, self.n_heads, -1, -1) + + return x, start_pos, freqs_cis, attn_mask + + def forward(self, x, start_pos, freqs_cis, mask): + """ + x: (batch, seq, hidden_dim) + start_pos: int + freqs_cis: ? + mask: ? 
+ + + return: (batch, seq, hidden_dim) + """ + result = self.attention( + x, + start_pos, + freqs_cis, + mask, + ) + return result def tt_llama_attention_prepare_inputs(llama_attention_model, x, start_pos, rope_theta, mode="decode"): - assert len(x.size()) == 3 - batch, seq_len, _ = x.shape - - - cache_name = lambda name: llama_attention_model.cache_path / (f"{name}") - - - if mode == "decode": - assert seq_len == 1, "Only supporting decode mode" - x = x.transpose(0, 1).unsqueeze(1) - assert x.shape == (seq_len, 1, batch, llama_attention_model.hidden_size) - - - ACT_MEMCFG = ttnn.create_sharded_memory_config( - shape=(x.shape[2], x.shape[3] // 32 // llama_attention_model.cluster_shape[0]), - core_grid=ttnn.CoreGrid(y=4, x=8), - strategy=ttnn.ShardStrategy.WIDTH, - orientation=ttnn.ShardOrientation.ROW_MAJOR, - use_height_and_width_as_shard_shape=True, - ) - xs = ttnn.as_tensor( - x, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - memory_config=ACT_MEMCFG, - device=llama_attention_model.mesh_device, - mesh_mapper=shard_tensor_to_2d_mesh_mapper( - llama_attention_model.mesh_device, mesh_shape=llama_attention_model.cluster_shape, dims=(None, 3) - ), - ) - - - batch_size_per_group = llama_attention_model.batch_size_per_device_group - - - rot_emb = generate_rot_emb(llama_attention_model.head_dim, llama_attention_model.max_seq_len * 2, rope_theta) - rot_mat = get_rotation_mat(rot_emb, start_pos, seq_len, batch=batch_size_per_group) - assert rot_mat.size() == ( - 1, - batch_size_per_group, - llama_attention_model.head_dim, - llama_attention_model.head_dim, - ) - - - shard_spec_n_cores_grid = ttnn.CoreRangeSet({num_to_corerange(batch_size_per_group)}) - ROT_MAT_MEMCFG = ttnn.MemoryConfig( - ttnn.TensorMemoryLayout.HEIGHT_SHARDED, - ttnn.BufferType.L1, - ttnn.ShardSpec( - shard_spec_n_cores_grid, - [ - llama_attention_model.head_dim, - llama_attention_model.head_dim, - ], - ttnn.ShardOrientation.ROW_MAJOR, - ), - ) - rot_mats = ttnn.as_tensor( - rot_mat, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - memory_config=ROT_MAT_MEMCFG, - device=llama_attention_model.mesh_device, - mesh_mapper=replicate_tensor_to_mesh_mapper(llama_attention_model.mesh_device), - ) - - - attn_masks = None - - - elif mode == "prefill": - assert ( - seq_len % 256 == 0 and seq_len > 0 and seq_len <= 8192 - ), "Prefill mode only supports seqlen as a multiple of 256 up to 8k" - assert batch == 1, "prefill mode only supports batch size 1" - x = x.unsqueeze(0) - assert x.shape == (1, batch, seq_len, llama_attention_model.hidden_size) - xs = ttnn.as_tensor( - x, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - device=llama_attention_model.mesh_device, - mesh_mapper=shard_tensor_to_2d_mesh_mapper( - llama_attention_model.mesh_device, mesh_shape=llama_attention_model.cluster_shape, dims=(None, 3) - ), - ) - - - cos, sin = precompute_freqs( - llama_attention_model.head_dim, llama_attention_model.max_seq_len * 2, rope_theta, use_scaled=False - ) - cos_gathered, sin_gathered = gather_cos_sin(torch.arange(start_pos, start_pos + seq_len), cos, sin) - assert cos_gathered.size() == (1, 1, seq_len, llama_attention_model.head_dim) - assert sin_gathered.size() == (1, 1, seq_len, llama_attention_model.head_dim) - - - cos_gathereds = ttnn.as_tensor( - cos_gathered, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - # cache_file_name=cache_name(f"cos_gathered_prefill_{seq_len}"), - memory_config=ttnn.DRAM_MEMORY_CONFIG, - device=llama_attention_model.mesh_device, - 
mesh_mapper=replicate_tensor_to_mesh_mapper(llama_attention_model.mesh_device), - ) - sin_gathereds = ttnn.as_tensor( - sin_gathered, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - # cache_file_name=cache_name(f"sin_gathered_prefill_{seq_len}"), - memory_config=ttnn.DRAM_MEMORY_CONFIG, - device=llama_attention_model.mesh_device, - mesh_mapper=replicate_tensor_to_mesh_mapper(llama_attention_model.mesh_device), - ) - - - rot_mats = [cos_gathereds, sin_gathereds] - - - attn_mask = torch.full((seq_len, seq_len), torch.finfo(torch.float32).min) - attn_mask = torch.triu(attn_mask, diagonal=1) - attn_mask = attn_mask.expand(1, batch, -1, -1) - attn_masks = ttnn.as_tensor( - attn_mask, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - # cache_file_name=cache_name(f"attn_mask_prefill_{seq_len}"), - mesh_mapper=replicate_tensor_to_mesh_mapper(llama_attention_model.mesh_device), - memory_config=ttnn.DRAM_MEMORY_CONFIG, - device=llama_attention_model.mesh_device, - ) - - - return ( - xs, - start_pos, - rot_mats, - attn_masks, - ) - - + assert len(x.size()) == 3 + batch, seq_len, _ = x.shape + + cache_name = lambda name: llama_attention_model.cache_path / (f"{name}") + + if mode == "decode": + assert seq_len == 1, "Only supporting decode mode" + x = x.transpose(0, 1).unsqueeze(1) + assert x.shape == (seq_len, 1, batch, llama_attention_model.hidden_size) + + ACT_MEMCFG = ttnn.create_sharded_memory_config( + shape=(x.shape[2], x.shape[3] // 32 // llama_attention_model.cluster_shape[0]), + core_grid=ttnn.CoreGrid(y=4, x=8), + strategy=ttnn.ShardStrategy.WIDTH, + orientation=ttnn.ShardOrientation.ROW_MAJOR, + use_height_and_width_as_shard_shape=True, + ) + xs = ttnn.as_tensor( + x, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ACT_MEMCFG, + device=llama_attention_model.mesh_device, + mesh_mapper=shard_tensor_to_2d_mesh_mapper( + llama_attention_model.mesh_device, mesh_shape=llama_attention_model.cluster_shape, dims=(None, 3) + ), + ) + + batch_size_per_group = llama_attention_model.batch_size_per_device_group + + rot_emb = generate_rot_emb(llama_attention_model.head_dim, llama_attention_model.max_seq_len * 2, rope_theta) + rot_mat = get_rotation_mat(rot_emb, start_pos, seq_len, batch=batch_size_per_group) + assert rot_mat.size() == ( + 1, + batch_size_per_group, + llama_attention_model.head_dim, + llama_attention_model.head_dim, + ) + + shard_spec_n_cores_grid = ttnn.CoreRangeSet({num_to_corerange(batch_size_per_group)}) + ROT_MAT_MEMCFG = ttnn.MemoryConfig( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED, + ttnn.BufferType.L1, + ttnn.ShardSpec( + shard_spec_n_cores_grid, + [ + llama_attention_model.head_dim, + llama_attention_model.head_dim, + ], + ttnn.ShardOrientation.ROW_MAJOR, + ), + ) + rot_mats = ttnn.as_tensor( + rot_mat, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ROT_MAT_MEMCFG, + device=llama_attention_model.mesh_device, + mesh_mapper=replicate_tensor_to_mesh_mapper(llama_attention_model.mesh_device), + ) + + attn_masks = None + + elif mode == "prefill": + assert ( + seq_len % 256 == 0 and seq_len > 0 and seq_len <= 8192 + ), "Prefill mode only supports seqlen as a multiple of 256 up to 8k" + assert batch == 1, "prefill mode only supports batch size 1" + x = x.unsqueeze(0) + assert x.shape == (1, batch, seq_len, llama_attention_model.hidden_size) + xs = ttnn.as_tensor( + x, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + device=llama_attention_model.mesh_device, + mesh_mapper=shard_tensor_to_2d_mesh_mapper( + 
llama_attention_model.mesh_device, mesh_shape=llama_attention_model.cluster_shape, dims=(None, 3) + ), + ) + + cos, sin = precompute_freqs( + llama_attention_model.head_dim, llama_attention_model.max_seq_len * 2, rope_theta, use_scaled=False + ) + cos_gathered, sin_gathered = gather_cos_sin(torch.arange(start_pos, start_pos + seq_len), cos, sin) + assert cos_gathered.size() == (1, 1, seq_len, llama_attention_model.head_dim) + assert sin_gathered.size() == (1, 1, seq_len, llama_attention_model.head_dim) + + cos_gathereds = ttnn.as_tensor( + cos_gathered, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + # cache_file_name=cache_name(f"cos_gathered_prefill_{seq_len}"), + memory_config=ttnn.DRAM_MEMORY_CONFIG, + device=llama_attention_model.mesh_device, + mesh_mapper=replicate_tensor_to_mesh_mapper(llama_attention_model.mesh_device), + ) + sin_gathereds = ttnn.as_tensor( + sin_gathered, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + # cache_file_name=cache_name(f"sin_gathered_prefill_{seq_len}"), + memory_config=ttnn.DRAM_MEMORY_CONFIG, + device=llama_attention_model.mesh_device, + mesh_mapper=replicate_tensor_to_mesh_mapper(llama_attention_model.mesh_device), + ) + + rot_mats = [cos_gathereds, sin_gathereds] + + attn_mask = torch.full((seq_len, seq_len), torch.finfo(torch.float32).min) + attn_mask = torch.triu(attn_mask, diagonal=1) + attn_mask = attn_mask.expand(1, batch, -1, -1) + attn_masks = ttnn.as_tensor( + attn_mask, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + # cache_file_name=cache_name(f"attn_mask_prefill_{seq_len}"), + mesh_mapper=replicate_tensor_to_mesh_mapper(llama_attention_model.mesh_device), + memory_config=ttnn.DRAM_MEMORY_CONFIG, + device=llama_attention_model.mesh_device, + ) + + return ( + xs, + start_pos, + rot_mats, + attn_masks, + ) def run_test_LlamaAttention_inference( - mesh_device, - cluster_shape, - batch, - seq_len, - pcc, - model_config, - llama_version, - ckpt_dir, - tokenizer_path, - cache_path, + mesh_device, + cluster_shape, + batch, + seq_len, + pcc, + model_config, + llama_version, + ckpt_dir, + tokenizer_path, + cache_path, ): - # Prepare paths and devices - skip_model_load = should_skip_model_load() - - - # Prepare configs - hugging_face_reference_model = Llama.build( - ckpt_dir, - tokenizer_path, - max_seq_len=MAX_SEQ_LEN if llama_version == "llama2" else MAX_SEQ_LEN_LLAMA3, - max_batch_size=batch, - n_layers=UNIT_TEST_N_LAYER, - skip_model_load=skip_model_load, - ).model - hugging_face_reference_model.eval() - state_dict = hugging_face_reference_model.state_dict() - logger.info(state_dict.keys()) - torch.manual_seed(0) - configuration = hugging_face_reference_model.params - - - # PyTorch model -------------------------------------------------------------------- - pytorch_LlamaAttention_model = PytorchLlamaAttentionModel( - hugging_face_reference_model, UNIT_TEST_LAYER_NUM, configuration.rope_theta - ) - # TT model ------------------------------------------------------------------------- - transformation_mat_torch = get_rot_transformation_mat(32) # 32 for tile size - - - transformation_mats = ttnn.as_tensor( - transformation_mat_torch, - dtype=ttnn.bfloat16, - layout=ttnn.TILE_LAYOUT, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - device=mesh_device, - mesh_mapper=replicate_tensor_to_mesh_mapper(mesh_device), - ) - - - tt_LlamaAttention_model = TtLlamaAttention_galaxy( - mesh_device, - cluster_shape, - state_dict, - BASE_URL, - UNIT_TEST_LAYER_NUM, - model_config, - configuration, - transformation_mats, - cache_path=cache_path, - ) - - - mode 
= "decode" if seq_len == 1 else "prefill" - - - all_tests_pass, all_pccs = True, [] - if mode == "prefill": - generation_start_pos = 0 - generation_length = 1 - else: - generation_start_pos = UNIT_TEST_START_POS - generation_length = UNIT_TEST_GENERATION_LENGTH - - - for i in range(generation_length): - # Prepare input - pt_inp_ids = torch.randint(0, configuration.vocab_size, (batch, seq_len)) - pt_inp = hugging_face_reference_model.tok_embeddings(pt_inp_ids) - pt_inp_normed = hugging_face_reference_model.layers[UNIT_TEST_LAYER_NUM].attention_norm(pt_inp) - tt_input = pt_inp_normed.clone() - start_pos = generation_start_pos + i - - - # PyTorch output -------------------------------------------------------------------- - if mode == "prefill": - attention_input, start_pos, freqs_cis, attn_mask = pytorch_LlamaAttention_model.prepare_inputs_prefill( - pt_inp_normed, start_pos - ) - else: - attention_input, start_pos, freqs_cis, attn_mask = pytorch_LlamaAttention_model.prepare_inputs( - pt_inp_normed, start_pos - ) - - - pytorch_out = pytorch_LlamaAttention_model( - attention_input, - start_pos, - freqs_cis, - attn_mask, - ) - - - # TT hardware execution ------------------------------------------------------------- - attention_input, start_pos, rot_mat, attn_mask = tt_llama_attention_prepare_inputs( - tt_LlamaAttention_model, tt_input, start_pos, configuration.rope_theta, mode=mode - ) - tt_out = tt_LlamaAttention_model( - attention_input, - rot_mat, - start_pos, - attn_mask, - mode=mode, - ) - # tt_out = [ttnn.to_torch(shard) for shard in ttnn.get_device_tensors(tt_out.cpu())] - - - tt_out = ttnn.to_torch( - tt_out, mesh_composer=ConcatMesh2DToTensor(mesh_device, dims=(3, 1), cluster_shape=cluster_shape) - ) - tt_out = tt_out[:, 0:1, :, :] - tt_out = tt_out.permute(2, 1, 0, 3).squeeze(1) # [seq, batch, hidden_dim] - - - does_pass, output_pcc = comp_pcc(pytorch_out, tt_out, pcc) - logger.info(f"Output: {output_pcc}") - - - all_pccs.append(extract_pcc_from_log(output_pcc)) - - - if does_pass: - logger.info(f"[start_pos={start_pos}] {llama_version} Attention output Passed!") - else: - logger.warning( - f"[start_pos={start_pos}] {llama_version} Attention output Failed! PCC value is lower than {pcc}" - ) - all_tests_pass = False - - - logger.info(f"Average PCC over {len(all_pccs)} tokens: {sum(all_pccs) / len(all_pccs)}") - - - # Check kv cache - # PyTorch output -------------------------------------------------------------------- - pytorch_layer_present = [ - pytorch_LlamaAttention_model.attention.cache_k.clone().permute(0, 2, 1, 3)[ - :batch, ... - ], # [batch, n_kv_heads, seq, head_dim] - pytorch_LlamaAttention_model.attention.cache_v.clone().permute(0, 2, 1, 3)[ - :batch, ... - ], # [batch, n_kv_heads, seq, head_dim] - ] - # TT hardware output ---------------------------------------------------------------- - - - # concat the pasts by heads - tt_layer_present_all = [ttnn.from_device(lp) for lp in tt_LlamaAttention_model.layer_past] - - - tt_layer_present_all = [ - ttnn.to_torch(lp, mesh_composer=ConcatMesh2DToTensor(mesh_device, dims=(0, 1), cluster_shape=cluster_shape))[ - :batch, ... 
- ] - for lp in tt_layer_present_all - ] - - - cache_test_pass = check_kv_cache( - pytorch_layer_present, - tt_layer_present_all, - generation_start_pos, - generation_length, - seq_len, - mode == "prefill", - pcc, - ) - - - all_tests_pass = all_tests_pass and cache_test_pass - - - if all_tests_pass: - logger.info(f"{llama_version} Attention output Passed!") - else: - gc.collect() - logger.warning(f"{llama_version} Attention output Failed!") - assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" - - + # Prepare paths and devices + skip_model_load = should_skip_model_load() + + # Prepare configs + hugging_face_reference_model = Llama.build( + ckpt_dir, + tokenizer_path, + max_seq_len=MAX_SEQ_LEN if llama_version == "llama2" else MAX_SEQ_LEN_LLAMA3, + max_batch_size=batch, + n_layers=UNIT_TEST_N_LAYER, + skip_model_load=skip_model_load, + ).model + hugging_face_reference_model.eval() + state_dict = hugging_face_reference_model.state_dict() + logger.info(state_dict.keys()) + torch.manual_seed(0) + configuration = hugging_face_reference_model.params + + # PyTorch model -------------------------------------------------------------------- + pytorch_LlamaAttention_model = PytorchLlamaAttentionModel( + hugging_face_reference_model, UNIT_TEST_LAYER_NUM, configuration.rope_theta + ) + # TT model ------------------------------------------------------------------------- + transformation_mat_torch = get_rot_transformation_mat(32) # 32 for tile size + + transformation_mats = ttnn.as_tensor( + transformation_mat_torch, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + device=mesh_device, + mesh_mapper=replicate_tensor_to_mesh_mapper(mesh_device), + ) + + tt_LlamaAttention_model = TtLlamaAttention_galaxy( + mesh_device, + cluster_shape, + state_dict, + BASE_URL, + UNIT_TEST_LAYER_NUM, + model_config, + configuration, + transformation_mats, + cache_path=cache_path, + ) + + mode = "decode" if seq_len == 1 else "prefill" + + all_tests_pass, all_pccs = True, [] + if mode == "prefill": + generation_start_pos = 0 + generation_length = 1 + else: + generation_start_pos = UNIT_TEST_START_POS + generation_length = UNIT_TEST_GENERATION_LENGTH + + for i in range(generation_length): + # Prepare input + pt_inp_ids = torch.randint(0, configuration.vocab_size, (batch, seq_len)) + pt_inp = hugging_face_reference_model.tok_embeddings(pt_inp_ids) + pt_inp_normed = hugging_face_reference_model.layers[UNIT_TEST_LAYER_NUM].attention_norm(pt_inp) + tt_input = pt_inp_normed.clone() + start_pos = generation_start_pos + i + + # PyTorch output -------------------------------------------------------------------- + if mode == "prefill": + attention_input, start_pos, freqs_cis, attn_mask = pytorch_LlamaAttention_model.prepare_inputs_prefill( + pt_inp_normed, start_pos + ) + else: + attention_input, start_pos, freqs_cis, attn_mask = pytorch_LlamaAttention_model.prepare_inputs( + pt_inp_normed, start_pos + ) + + pytorch_out = pytorch_LlamaAttention_model( + attention_input, + start_pos, + freqs_cis, + attn_mask, + ) + + # TT hardware execution ------------------------------------------------------------- + attention_input, start_pos, rot_mat, attn_mask = tt_llama_attention_prepare_inputs( + tt_LlamaAttention_model, tt_input, start_pos, configuration.rope_theta, mode=mode + ) + tt_out = tt_LlamaAttention_model( + attention_input, + rot_mat, + start_pos, + attn_mask, + mode=mode, + ) + # tt_out = [ttnn.to_torch(shard) for shard in 
ttnn.get_device_tensors(tt_out.cpu())] + + tt_out = ttnn.to_torch( + tt_out, mesh_composer=ConcatMesh2DToTensor(mesh_device, dims=(3, 1), cluster_shape=cluster_shape) + ) + tt_out = tt_out[:, 0:1, :, :] + tt_out = tt_out.permute(2, 1, 0, 3).squeeze(1) # [seq, batch, hidden_dim] + + does_pass, output_pcc = comp_pcc(pytorch_out, tt_out, pcc) + logger.info(f"Output: {output_pcc}") + + all_pccs.append(extract_pcc_from_log(output_pcc)) + + if does_pass: + logger.info(f"[start_pos={start_pos}] {llama_version} Attention output Passed!") + else: + logger.warning( + f"[start_pos={start_pos}] {llama_version} Attention output Failed! PCC value is lower than {pcc}" + ) + all_tests_pass = False + + logger.info(f"Average PCC over {len(all_pccs)} tokens: {sum(all_pccs) / len(all_pccs)}") + + # Check kv cache + # PyTorch output -------------------------------------------------------------------- + pytorch_layer_present = [ + pytorch_LlamaAttention_model.attention.cache_k.clone().permute(0, 2, 1, 3)[ + :batch, ... + ], # [batch, n_kv_heads, seq, head_dim] + pytorch_LlamaAttention_model.attention.cache_v.clone().permute(0, 2, 1, 3)[ + :batch, ... + ], # [batch, n_kv_heads, seq, head_dim] + ] + # TT hardware output ---------------------------------------------------------------- + + # concat the pasts by heads + tt_layer_present_all = [ttnn.from_device(lp) for lp in tt_LlamaAttention_model.layer_past] + + tt_layer_present_all = [ + ttnn.to_torch(lp, mesh_composer=ConcatMesh2DToTensor(mesh_device, dims=(0, 1), cluster_shape=cluster_shape))[ + :batch, ... + ] + for lp in tt_layer_present_all + ] + + cache_test_pass = check_kv_cache( + pytorch_layer_present, + tt_layer_present_all, + generation_start_pos, + generation_length, + seq_len, + mode == "prefill", + pcc, + ) + + all_tests_pass = all_tests_pass and cache_test_pass + + if all_tests_pass: + logger.info(f"{llama_version} Attention output Passed!") + else: + gc.collect() + logger.warning(f"{llama_version} Attention output Failed!") + assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" 
@skip_for_grayskull("Requires eth connected devices to run") @pytest.mark.parametrize( - "cluster_shape, mesh_device", [pytest.param((4, 8), (8, 4), id="4x8_grid")], indirect=["mesh_device"] + "cluster_shape, mesh_device", [pytest.param((4, 8), (8, 4), id="4x8_grid")], indirect=["mesh_device"] ) @pytest.mark.parametrize( - "llama_version", - (("llama3-tg"),), + "llama_version", + (("llama3-tg"),), ) @pytest.mark.parametrize( - "batch, seq_len, pcc", - [ - (32, 1, 0.9995), - (1, 256, 0.999), - ], - ids=[ - "decode", - "prefill", - ], + "batch, seq_len, pcc", + [ + (32, 1, 0.9995), + (1, 256, 0.999), + ], + ids=[ + "decode", + "prefill", + ], ) @pytest.mark.parametrize( - "max_batch_size, max_context_len", - ( - (32, 2048), - # (16, 8192), - ), - ids=( - "short_context", - # "long_context", - ), + "max_batch_size, max_context_len", + ( + (32, 2048), + # (16, 8192), + ), + ids=( + "short_context", + # "long_context", + ), ) def test_LlamaAttention_inference( - batch, - seq_len, - pcc, - mesh_device, - max_batch_size, - max_context_len, - llama_version, - cluster_shape, - use_program_cache, + batch, + seq_len, + pcc, + mesh_device, + max_batch_size, + max_context_len, + llama_version, + cluster_shape, + use_program_cache, ): - if batch > max_batch_size: - pytest.skip(f"Decode with {batch} users is not supported with large context") - - - if batch == 1 and seq_len > max_context_len: - pytest.skip(f"Prefill with {seq_len=} is not supported with short context") - - - if llama_version == "llama2" and seq_len > 2048: - pytest.skip(f"Llama2 with {seq_len=} is not supported (max 2048)") - - - model_config, ckpt_dir, tokenizer_path, cache_path = setup_llama_env( - llama_version=llama_version, - max_batch_size=max_batch_size, - max_context_len=max_context_len, - ) - check_mesh_device(mesh_device, model_config) - run_test_LlamaAttention_inference( - mesh_device, - cluster_shape, - batch, - seq_len, - pcc, - model_config, - llama_version, - ckpt_dir, - tokenizer_path, - cache_path, - ) + if batch > max_batch_size: + pytest.skip(f"Decode with {batch} users is not supported with large context") + + if batch == 1 and seq_len > max_context_len: + pytest.skip(f"Prefill with {seq_len=} is not supported with short context") + + if llama_version == "llama2" and seq_len > 2048: + pytest.skip(f"Llama2 with {seq_len=} is not supported (max 2048)") + + model_config, ckpt_dir, tokenizer_path, cache_path = setup_llama_env( + llama_version=llama_version, + max_batch_size=max_batch_size, + max_context_len=max_context_len, + ) + check_mesh_device(mesh_device, model_config) + run_test_LlamaAttention_inference( + mesh_device, + cluster_shape, + batch, + seq_len, + pcc, + model_config, + llama_version, + ckpt_dir, + tokenizer_path, + cache_path, + ) diff --git a/models/demos/tg/llama3_70b/tt/llama_decoder_galaxy.py b/models/demos/tg/llama3_70b/tt/llama_decoder_galaxy.py index 5d3fae48334..2404016e361 100644 --- a/models/demos/tg/llama3_70b/tt/llama_decoder_galaxy.py +++ b/models/demos/tg/llama3_70b/tt/llama_decoder_galaxy.py @@ -110,8 +110,7 @@ def load_weights(self): layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=shard_tensor_to_2d_mesh_mapper -(self.mesh_device, self.cluster_shape, (None, 2)), + mesh_mapper=shard_tensor_to_2d_mesh_mapper(self.mesh_device, self.cluster_shape, (None, 2)), cache_file_name=self.cache_path / attn_norm_sharded_str, ) @@ -121,8 +120,7 @@ def load_weights(self): layout=ttnn.ROW_MAJOR_LAYOUT, device=self.mesh_device, 
memory_config=ttnn.DRAM_MEMORY_CONFIG, - mesh_mapper=shard_tensor_to_2d_mesh_mapper -(self.mesh_device, self.cluster_shape, (None, 2)), + mesh_mapper=shard_tensor_to_2d_mesh_mapper(self.mesh_device, self.cluster_shape, (None, 2)), cache_file_name=self.cache_path / ffn_norm_sharded_str, ) diff --git a/models/demos/tg/llama3_70b/tt/llama_mlp_galaxy.py b/models/demos/tg/llama3_70b/tt/llama_mlp_galaxy.py index de8c1b3d11d..068ad25c44d 100644 --- a/models/demos/tg/llama3_70b/tt/llama_mlp_galaxy.py +++ b/models/demos/tg/llama3_70b/tt/llama_mlp_galaxy.py @@ -12,6 +12,7 @@ ) from ttnn import shard_tensor_to_2d_mesh_mapper + class TtLlamaMLP_galaxy: def __init__( self, diff --git a/models/demos/tg/llama3_70b/tt/llama_model_galaxy.py b/models/demos/tg/llama3_70b/tt/llama_model_galaxy.py index 429673467cf..d5abeb42724 100644 --- a/models/demos/tg/llama3_70b/tt/llama_model_galaxy.py +++ b/models/demos/tg/llama3_70b/tt/llama_model_galaxy.py @@ -26,6 +26,7 @@ ) from ttnn import shard_tensor_to_2d_mesh_mapper + def is_power_of_two(n): if n <= 0: return False diff --git a/tests/ttnn/unit_tests/test_multi_device_async.py b/tests/ttnn/unit_tests/test_multi_device_async.py index a4822361d11..8f9e8c5b5df 100644 --- a/tests/ttnn/unit_tests/test_multi_device_async.py +++ b/tests/ttnn/unit_tests/test_multi_device_async.py @@ -31,7 +31,10 @@ def test_ttnn_to_and_from_multi_device_shard(pcie_mesh_device, layout, memory_co for i in range(100): torch_tensor = torch.rand((1, 1, 256, 512), dtype=torch.bfloat16) ttnn_tensor = ttnn.from_torch( - torch_tensor, dtype=dtype, layout=layout, mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(pcie_mesh_device, dim=3) + torch_tensor, + dtype=dtype, + layout=layout, + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(pcie_mesh_device, dim=3), ) ttnn_tensor = ttnn.to_device(ttnn_tensor, pcie_mesh_device, memory_config=memory_config) ttnn_loop_back_tensor = ttnn.from_device(ttnn_tensor) @@ -63,7 +66,10 @@ def test_multi_device_check_per_device_shard(pcie_mesh_device, layout, memory_co torch_tensor = torch.rand((8, 1, 1024, 1024), dtype=torch.bfloat16) ttnn_tensor = ttnn.from_torch( - torch_tensor, dtype=dtype, layout=layout, mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(pcie_mesh_device, dim=3) + torch_tensor, + dtype=dtype, + layout=layout, + mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(pcie_mesh_device, dim=3), ) ttnn_tensor = ttnn.to_device(ttnn_tensor, pcie_mesh_device, memory_config=memory_config) ttnn_loop_back_tensor = ttnn.from_device(ttnn_tensor) diff --git a/ttnn/ttnn/distributed/distributed.py b/ttnn/ttnn/distributed/distributed.py index d639eae06f2..84abb481e26 100644 --- a/ttnn/ttnn/distributed/distributed.py +++ b/ttnn/ttnn/distributed/distributed.py @@ -189,7 +189,8 @@ def create_mesh_device(*args, **kwargs): yield mesh_device finally: close_mesh_device(mesh_device) - + + def synchronize_devices( devices: Union["ttnn.Device", "ttnn.MeshDevice"], queue_id: Optional[int] = ttnn.DefaultQueueId, @@ -293,6 +294,7 @@ def compose(self, tensor: ttnn.Tensor) -> "torch.Tensor": ] return torch.cat(device_shards_converted_to_torch, dim=self.concat_dim) + @contextlib.contextmanager def distribute(default: Union[ttnn.TensorToMesh, ttnn.MeshToTensor, MeshToTensor]): """ @@ -330,4 +332,5 @@ def distribute(default: Union[ttnn.TensorToMesh, ttnn.MeshToTensor, MeshToTensor ttnn.from_torch = _original_from_torch ttnn.to_torch = _original_to_torch + __all__ = [] From 3746a28b2d0255a37a70c5a44051e0b187131714 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Thu, 6 Mar 2025 20:58:37 +0000 Subject: 
[PATCH 76/76] add back distributed to imports, fix it --- ttnn/ttnn/distributed/__init__.py | 1 + ttnn/ttnn/distributed/distributed.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/ttnn/ttnn/distributed/__init__.py b/ttnn/ttnn/distributed/__init__.py index 9306b213785..ab0b239154d 100644 --- a/ttnn/ttnn/distributed/__init__.py +++ b/ttnn/ttnn/distributed/__init__.py @@ -18,4 +18,5 @@ create_mesh_device, synchronize_devices, visualize_mesh_device, + distribute, ) diff --git a/ttnn/ttnn/distributed/distributed.py b/ttnn/ttnn/distributed/distributed.py index 84abb481e26..ec7480a4564 100644 --- a/ttnn/ttnn/distributed/distributed.py +++ b/ttnn/ttnn/distributed/distributed.py @@ -296,7 +296,7 @@ def compose(self, tensor: ttnn.Tensor) -> "torch.Tensor": @contextlib.contextmanager -def distribute(default: Union[ttnn.TensorToMesh, ttnn.MeshToTensor, MeshToTensor]): +def distribute(default: Union[ttnn.TensorToMesh, ttnn.CppMeshToTensor, MeshToTensor]): """ Context manager to temporarily modify the behavior of ttnn.from_torch and ttnn.to_torch to use the specified mesh_mapper or mesh_composer for tensor distribution and composition to/from MeshDevice. @@ -319,9 +319,9 @@ def distribute(default: Union[ttnn.TensorToMesh, ttnn.MeshToTensor, MeshToTensor _original_from_torch = ttnn.from_torch try: - if isinstance(default, ttnn.TensorToMesh) or isinstance(default, ttnn.MeshToTensor): + if isinstance(default, ttnn.TensorToMesh): ttnn.from_torch = functools.partial(_original_from_torch, mesh_mapper=default) - elif isinstance(default, MeshToTensor): + elif isinstance(default, MeshToTensor) or isinstance(default, ttnn.CppMeshToTensor): ttnn.to_torch = functools.partial(_original_to_torch, mesh_composer=default) else: raise ValueError("Argument must be an instance of either TensorToMesh or MeshToTensor.")
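Taken together, these patches replace the class-style mesh mappers (ShardTensorToMesh, ReplicateTensorToMesh, and friends) at the Python call sites with the factory functions exposed by the new bindings. A minimal sketch of the resulting call pattern, assuming an already-open ttnn.MeshDevice named mesh_device; the tensor shape, dtype, and layout arguments below are illustrative and simply mirror the updated tests rather than any required signature:

    import torch
    import ttnn

    torch_tensor = torch.rand((1, 1, 32, 8192), dtype=torch.bfloat16)

    # Shard along the last dim across the mesh
    # (previously ttnn.ShardTensorToMesh(mesh_device, dim=3)).
    sharded = ttnn.from_torch(
        torch_tensor,
        dtype=ttnn.bfloat16,
        layout=ttnn.TILE_LAYOUT,
        device=mesh_device,
        mesh_mapper=ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3),
    )

    # Replicate the same host tensor to every device in the mesh
    # (previously ttnn.ReplicateTensorToMesh(mesh_device)).
    replicated = ttnn.from_torch(
        torch_tensor,
        dtype=ttnn.bfloat16,
        layout=ttnn.TILE_LAYOUT,
        device=mesh_device,
        mesh_mapper=ttnn.replicate_tensor_to_mesh_mapper(mesh_device),
    )

    # Composition back to a single torch tensor still goes through a mesh composer,
    # constructed as a class in the tests above.
    full = ttnn.to_torch(sharded, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=3))

The mesh_mapper= and mesh_composer= keywords are kept explicit here on purpose: several of the intermediate patches above ("fix replacement errors", "fix replace errors x3") exist solely to restore those keywords after the search-and-replace dropped them.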
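The final patch also reworks the distribute() context manager so its isinstance checks accept the new C++-backed composer type (ttnn.CppMeshToTensor) alongside the Python MeshToTensor wrappers. A usage sketch based on the docstring updated earlier in the series, again assuming mesh_device is an open ttnn.MeshDevice and that distribute is re-exported at the ttnn top level as the other distributed helpers are:

    import torch
    import ttnn

    torch_tensor = torch.rand((1, 1, 32, 8192), dtype=torch.bfloat16)

    # Every from_torch call inside the context picks up the default mapper,
    # so no per-call mesh_mapper= argument is needed.
    with ttnn.distribute(ttnn.shard_tensor_to_mesh_mapper(mesh_device, dim=3)):
        sharded = ttnn.from_torch(torch_tensor, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT)

Composers work the same way: passing a MeshToTensor subclass (or, after this patch, a ttnn.CppMeshToTensor instance) makes ttnn.to_torch use it as the default mesh_composer for the duration of the context.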