Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#0: Add borrowed storage support for the aggregate_as_tensor function #18555

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

jjiangTT
Copy link
Contributor

@jjiangTT jjiangTT commented Mar 3, 2025

Ticket

Link to Github Issue

Problem description

The aggregate_as_tensor function currently goes straight to assuming anything that isn't an owned tensor is a device tensor and then failing. Borrowed storage is currently the "preferred default" so newly created non-bfp_b types would require a conversion before they could be aggregated otherwise.

What's changed

Added support for borrowed storage tensors by turning them into owned tensors and then aggregating them into a multidevicehost storage backed tensor.
Added TT_FATAL checking to prevent any other types from crashing on the storage variant fetch and give some more visibility into failures.

Checklist

Comment on lines 103 to 110
auto visitor = tt::stl::overloaded{[&shard, &host_owned_buffers](const auto& buffer) -> OwnedBuffer {
using BufferType = std::decay_t<decltype(buffer)>;
using ValueType = typename BufferType::value_type;

std::vector<ValueType> physical_data(buffer.begin(), buffer.end());

std::vector<ValueType> logical_data =
tensor_impl::decode_tensor_data(std::move(physical_data), shard.get_tensor_spec());

return owned_buffer::create(std::move(logical_data));
}};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need this? Does this work:

auto borrowed_buffer = std::get<BorrowedStorage>(shard.get_storage()).buffer;
auto owned_buffer =  std::visit(
            [](auto&& buffer) {
                using BorrowedStorageType = std::vector<std::decay_t<decltype(*(buffer.begin()))>>;
                return owned_buffer::create(BorrowedStorageType(buffer.begin(), buffer.end()));
            },
            borrowed_buffer);

@@ -93,7 +93,31 @@ Tensor aggregate_as_tensor(
}
auto storage = MultiDeviceHostStorage{config, std::move(host_owned_buffers), specs};
return Tensor(std::move(storage), reference_shard.get_tensor_spec());
} else if (storage_type == StorageType::BORROWED) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is formatted?

@jjiangTT jjiangTT force-pushed the jjiang/support_borrowed_aggregate branch 2 times, most recently from eb73a5e to 84c65ae Compare March 3, 2025 21:54
@jjiangTT jjiangTT requested a review from omilyutin-tt March 3, 2025 21:54
@jjiangTT jjiangTT force-pushed the jjiang/support_borrowed_aggregate branch 5 times, most recently from cf2f060 to e2a46b1 Compare March 6, 2025 18:47
@jjiangTT jjiangTT force-pushed the jjiang/support_borrowed_aggregate branch from 4e5f8a9 to 0f8151f Compare March 6, 2025 23:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants