Skip to content

Commit

Permalink
simplify borrowed aggregate
Browse files Browse the repository at this point in the history
  • Loading branch information
jjiangTT committed Mar 3, 2025
1 parent 68db73c commit 84c65ae
Showing 1 changed file with 5 additions and 8 deletions.
13 changes: 5 additions & 8 deletions ttnn/cpp/ttnn/distributed/api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@
#include <memory>

#include <tt-metalium/overloaded.hpp>
#include "tt-metalium/assert.hpp"
#include "tt-metalium/mesh_coord.hpp"
#include "ttnn/tensor/tensor.hpp"
#include "ttnn/tensor/host_buffer/functions.hpp"
#include "ttnn/tensor/tensor_utils.hpp"
#include "ttnn/distributed/distributed_tensor_config.hpp"
#include <tt-metalium/mesh_device.hpp>
#include <tt-metalium/system_mesh.hpp>
#include "ttnn/distributed/distributed_tensor_config.hpp"


using namespace tt::tt_metal;

namespace ttnn::distributed {
Expand Down Expand Up @@ -101,15 +104,9 @@ Tensor aggregate_as_tensor(
specs.push_back(shard.get_tensor_spec());

auto visitor = tt::stl::overloaded{[&shard, &host_owned_buffers](const auto& buffer) -> OwnedBuffer {
using BufferType = std::decay_t<decltype(buffer)>;
using ValueType = typename BufferType::value_type;

std::vector<ValueType> physical_data(buffer.begin(), buffer.end());

std::vector<ValueType> logical_data =
tensor_impl::decode_tensor_data(std::move(physical_data), shard.get_tensor_spec());
using BorrowedBufferType = std::vector<typename std::decay_t<decltype(buffer)>::value_type>;

return owned_buffer::create(std::move(logical_data));
return owned_buffer::create(BorrowedBufferType(buffer.begin(), buffer.end()));
}};

host_owned_buffers.push_back(std::visit(visitor, buffer));
Expand Down

0 comments on commit 84c65ae

Please sign in to comment.