Skip to content

Commit

Permalink
Refactor CC CompiledModel buffer creation functions.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 729672089
  • Loading branch information
tensorflower-gardener committed Feb 21, 2025
1 parent 9c9dee8 commit 3bcb15f
Showing 1 changed file with 55 additions and 142 deletions.
197 changes: 55 additions & 142 deletions tensorflow/lite/experimental/litert/cc/litert_compiled_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

#include "tensorflow/lite/experimental/litert/cc/litert_compiled_model.h"

#include <algorithm>
#include <cstddef>
#include <iterator>
#include <memory>
#include <utility>
#include <vector>
Expand All @@ -23,10 +25,9 @@
#include "absl/strings/string_view.h"
#include "tensorflow/lite/experimental/litert/c/litert_common.h"
#include "tensorflow/lite/experimental/litert/c/litert_compiled_model.h"
#include "tensorflow/lite/experimental/litert/c/litert_model.h"
#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h"
#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h"
#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
#include "tensorflow/lite/experimental/litert/cc/litert_macros.h"
#include "tensorflow/lite/experimental/litert/cc/litert_model.h"
#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h"
#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h"
Expand All @@ -35,178 +36,90 @@ namespace litert {

Expected<size_t> CompiledModel::FindInputIndex(
size_t signature_index, absl::string_view input_name) const {
auto signature = model_.GetSignature(signature_index);
if (!signature) {
return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find signature");
}
for (int i = 0; i < signature->InputNames().size(); ++i) {
if (signature->InputNames()[i] == input_name) {
return i;
}
LITERT_ASSIGN_OR_RETURN(const Signature& signature,
model_.GetSignature(signature_index));
const std::vector<absl::string_view>& input_names = signature.InputNames();
auto it = std::find(input_names.begin(), input_names.end(), input_name);
if (it != input_names.end()) {
return std::distance(input_names.begin(), it);
}
return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find input");
}

Expected<size_t> CompiledModel::FindOutputIndex(
size_t signature_index, absl::string_view output_name) const {
auto signature = model_.GetSignature(signature_index);
if (!signature) {
return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find signature");
}
for (int i = 0; i < signature->OutputNames().size(); ++i) {
if (signature->OutputNames()[i] == output_name) {
return i;
}
LITERT_ASSIGN_OR_RETURN(const Signature& signature,
model_.GetSignature(signature_index));
const std::vector<absl::string_view>& output_names = signature.OutputNames();
auto it = std::find(output_names.begin(), output_names.end(), output_name);
if (it != output_names.end()) {
return std::distance(output_names.begin(), it);
}
return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find output");
}

Expected<TensorBuffer> CompiledModel::CreateBufferImpl(
const TensorBufferRequirements& buffer_requirements,
const RankedTensorType& tensor_type) {
auto supported_types = buffer_requirements.SupportedTypes();
if (!supported_types) {
return supported_types.Error();
}
if (supported_types->empty()) {
LITERT_ASSIGN_OR_RETURN(
const std::vector<LiteRtTensorBufferType>& supported_types,
buffer_requirements.SupportedTypes());
if (supported_types.empty()) {
return Unexpected(kLiteRtStatusErrorRuntimeFailure,
"Input doesn't support any tensor buffer types");
}
// For simplicity we just pick the first supported tensor buffer type.
LiteRtTensorBufferType tensor_buffer_type = (*supported_types)[0];
LiteRtTensorBufferType tensor_buffer_type = supported_types[0];
LITERT_ASSIGN_OR_RETURN(size_t buffer_size, buffer_requirements.BufferSize());

auto buffer =
TensorBuffer::CreateManaged(tensor_buffer_type, tensor_type,
buffer_requirements.BufferSize().Value());
if (!buffer) {
return Unexpected(kLiteRtStatusErrorRuntimeFailure,
buffer.Error().Message());
}

return std::move(*buffer);
LITERT_ASSIGN_OR_RETURN(TensorBuffer buffer,
TensorBuffer::CreateManaged(
tensor_buffer_type, tensor_type, buffer_size));
return buffer;
}

Expected<TensorBuffer> CompiledModel::CreateInputOutputBuffer(
absl::string_view signature_name, absl::string_view tensor_name,
bool is_input) const {
auto signature_index = model_.GetSignatureIndex(signature_name);
if (!signature_index) {
return signature_index.Error();
}
auto signature = model_.GetSignature(*signature_index);
if (!signature) {
return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find signature");
}
if (!signature) {
return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find signature");
}
auto subgraph = model_.Subgraph(signature->Key());
if (!subgraph) {
return Unexpected(kLiteRtStatusErrorNotFound, "Failed to get subgraph");
}

LiteRtTensor target_litert_tensor;
LiteRtTensorBufferRequirements litert_buffer_requirements;
if (is_input) {
auto input_tensor = subgraph->Input(tensor_name);
if (!input_tensor) {
return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find input");
}
target_litert_tensor = input_tensor->Get();
auto input_buffer_requirements =
GetInputBufferRequirements(*signature_index, tensor_name);
if (!input_buffer_requirements) {
return Unexpected(kLiteRtStatusErrorRuntimeFailure,
input_buffer_requirements.Error().Message());
}
litert_buffer_requirements = input_buffer_requirements->Get();
} else {
auto output_tensor = subgraph->Output(tensor_name);
if (!output_tensor) {
return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find output");
}
target_litert_tensor = output_tensor->Get();
auto output_buffer_requirements =
GetOutputBufferRequirements(*signature_index, tensor_name);
if (!output_buffer_requirements) {
return Unexpected(kLiteRtStatusErrorRuntimeFailure,
output_buffer_requirements.Error().Message());
}
litert_buffer_requirements = output_buffer_requirements->Get();
}

auto buffer_requirements =
TensorBufferRequirements(litert_buffer_requirements, /*owned=*/false);
auto target_tensor = Tensor(target_litert_tensor);
auto tensor_type = target_tensor.RankedTensorType();
if (!tensor_type) {
return tensor_type.Error();
}
return CreateBufferImpl(buffer_requirements, *tensor_type);
LITERT_ASSIGN_OR_RETURN(size_t signature_index,
model_.GetSignatureIndex(signature_name));
LITERT_ASSIGN_OR_RETURN(Signature signature,
model_.GetSignature(signature_index));

LITERT_ASSIGN_OR_RETURN(Subgraph subgraph, model_.Subgraph(signature.Key()));

Expected<Tensor> tensor_expected =
is_input ? subgraph.Input(tensor_name) : subgraph.Output(tensor_name);
Expected<TensorBufferRequirements> buffer_requirements_expected =
is_input ? GetInputBufferRequirements(signature_index, tensor_name)
: GetOutputBufferRequirements(signature_index, tensor_name);

LITERT_ASSIGN_OR_RETURN(const Tensor& tensor, tensor_expected);
LITERT_ASSIGN_OR_RETURN(const TensorBufferRequirements& buffer_requirements,
buffer_requirements_expected);
LITERT_ASSIGN_OR_RETURN(const RankedTensorType& tensor_type,
tensor.RankedTensorType());

return CreateBufferImpl(buffer_requirements, tensor_type);
}

Expected<std::vector<TensorBuffer>> CompiledModel::CreateInputOutputBuffers(
size_t signature_index, bool is_input) const {
auto signature = model_.GetSignature(signature_index);
if (!signature) {
return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find signature");
}
auto subgraph = model_.Subgraph(signature->Key());
if (!subgraph) {
return Unexpected(kLiteRtStatusErrorNotFound, "Failed to get subgraph");
}
LITERT_ASSIGN_OR_RETURN(const Signature& signature,
model_.GetSignature(signature_index));
LITERT_ASSIGN_OR_RETURN(const Subgraph subgraph,
model_.Subgraph(signature.Key()));
std::vector<TensorBuffer> tensor_buffers;
std::vector<absl::string_view> tensor_names;
if (is_input) {
tensor_names = signature->InputNames();
} else {
tensor_names = signature->OutputNames();
}

tensor_names = is_input ? signature.InputNames() : signature.OutputNames();
tensor_buffers.reserve(tensor_names.size());

for (int i = 0; i < tensor_names.size(); ++i) {
LiteRtTensor target_litert_tensor;
LiteRtTensorBufferRequirements litert_buffer_requirements;
if (is_input) {
auto input_buffer_requirements =
GetInputBufferRequirements(signature_index, i);
if (!input_buffer_requirements) {
return Unexpected(kLiteRtStatusErrorRuntimeFailure,
input_buffer_requirements.Error().Message());
}
litert_buffer_requirements = input_buffer_requirements->Get();
auto input_tensor = subgraph->Input(tensor_names[i]);
if (!input_tensor) {
return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find input");
}
target_litert_tensor = input_tensor->Get();
} else {
auto output_buffer_requirements =
GetOutputBufferRequirements(signature_index, i);
if (!output_buffer_requirements) {
return Unexpected(kLiteRtStatusErrorRuntimeFailure,
output_buffer_requirements.Error().Message());
}
litert_buffer_requirements = output_buffer_requirements->Get();
auto output_tensor = subgraph->Output(tensor_names[i]);
if (!output_tensor) {
return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find output");
}
target_litert_tensor = output_tensor->Get();
}

auto buffer_requirements =
TensorBufferRequirements(litert_buffer_requirements, /*owned=*/false);
auto target_tensor = Tensor(target_litert_tensor);
auto tensor_type = target_tensor.RankedTensorType();
if (!tensor_type) {
return tensor_type.Error();
}
auto tensor_buffer = CreateBufferImpl(buffer_requirements, *tensor_type);
if (!tensor_buffer) {
return tensor_buffer.Error();
}
tensor_buffers.push_back(std::move(*tensor_buffer));
LITERT_ASSIGN_OR_RETURN(
TensorBuffer tensor_buffer,
CreateInputOutputBuffer(signature.Key(), tensor_names[i], is_input));
tensor_buffers.push_back(std::move(tensor_buffer));
}

return tensor_buffers;
Expand Down

0 comments on commit 3bcb15f

Please sign in to comment.