diff --git a/tensorflow/lite/experimental/litert/cc/litert_compiled_model.cc b/tensorflow/lite/experimental/litert/cc/litert_compiled_model.cc index 0e61c48189c062..d7c6872034be85 100644 --- a/tensorflow/lite/experimental/litert/cc/litert_compiled_model.cc +++ b/tensorflow/lite/experimental/litert/cc/litert_compiled_model.cc @@ -14,7 +14,9 @@ #include "tensorflow/lite/experimental/litert/cc/litert_compiled_model.h" +#include #include +#include #include #include #include @@ -23,10 +25,9 @@ #include "absl/strings/string_view.h" #include "tensorflow/lite/experimental/litert/c/litert_common.h" #include "tensorflow/lite/experimental/litert/c/litert_compiled_model.h" -#include "tensorflow/lite/experimental/litert/c/litert_model.h" #include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" -#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" #include "tensorflow/lite/experimental/litert/cc/litert_expected.h" +#include "tensorflow/lite/experimental/litert/cc/litert_macros.h" #include "tensorflow/lite/experimental/litert/cc/litert_model.h" #include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" #include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h" @@ -35,28 +36,24 @@ namespace litert { Expected CompiledModel::FindInputIndex( size_t signature_index, absl::string_view input_name) const { - auto signature = model_.GetSignature(signature_index); - if (!signature) { - return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find signature"); - } - for (int i = 0; i < signature->InputNames().size(); ++i) { - if (signature->InputNames()[i] == input_name) { - return i; - } + LITERT_ASSIGN_OR_RETURN(const Signature& signature, + model_.GetSignature(signature_index)); + const std::vector& input_names = signature.InputNames(); + auto it = std::find(input_names.begin(), input_names.end(), input_name); + if (it != input_names.end()) { + return std::distance(input_names.begin(), it); } return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find input"); } Expected CompiledModel::FindOutputIndex( size_t signature_index, absl::string_view output_name) const { - auto signature = model_.GetSignature(signature_index); - if (!signature) { - return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find signature"); - } - for (int i = 0; i < signature->OutputNames().size(); ++i) { - if (signature->OutputNames()[i] == output_name) { - return i; - } + LITERT_ASSIGN_OR_RETURN(const Signature& signature, + model_.GetSignature(signature_index)); + const std::vector& output_names = signature.OutputNames(); + auto it = std::find(output_names.begin(), output_names.end(), output_name); + if (it != output_names.end()) { + return std::distance(output_names.begin(), it); } return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find output"); } @@ -64,149 +61,65 @@ Expected CompiledModel::FindOutputIndex( Expected CompiledModel::CreateBufferImpl( const TensorBufferRequirements& buffer_requirements, const RankedTensorType& tensor_type) { - auto supported_types = buffer_requirements.SupportedTypes(); - if (!supported_types) { - return supported_types.Error(); - } - if (supported_types->empty()) { + LITERT_ASSIGN_OR_RETURN( + const std::vector& supported_types, + buffer_requirements.SupportedTypes()); + if (supported_types.empty()) { return Unexpected(kLiteRtStatusErrorRuntimeFailure, "Input doesn't support any tensor buffer types"); } // For simplicity we just pick the first supported tensor buffer type. - LiteRtTensorBufferType tensor_buffer_type = (*supported_types)[0]; + LiteRtTensorBufferType tensor_buffer_type = supported_types[0]; + LITERT_ASSIGN_OR_RETURN(size_t buffer_size, buffer_requirements.BufferSize()); - auto buffer = - TensorBuffer::CreateManaged(tensor_buffer_type, tensor_type, - buffer_requirements.BufferSize().Value()); - if (!buffer) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - buffer.Error().Message()); - } - - return std::move(*buffer); + LITERT_ASSIGN_OR_RETURN(TensorBuffer buffer, + TensorBuffer::CreateManaged( + tensor_buffer_type, tensor_type, buffer_size)); + return buffer; } Expected CompiledModel::CreateInputOutputBuffer( absl::string_view signature_name, absl::string_view tensor_name, bool is_input) const { - auto signature_index = model_.GetSignatureIndex(signature_name); - if (!signature_index) { - return signature_index.Error(); - } - auto signature = model_.GetSignature(*signature_index); - if (!signature) { - return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find signature"); - } - if (!signature) { - return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find signature"); - } - auto subgraph = model_.Subgraph(signature->Key()); - if (!subgraph) { - return Unexpected(kLiteRtStatusErrorNotFound, "Failed to get subgraph"); - } - - LiteRtTensor target_litert_tensor; - LiteRtTensorBufferRequirements litert_buffer_requirements; - if (is_input) { - auto input_tensor = subgraph->Input(tensor_name); - if (!input_tensor) { - return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find input"); - } - target_litert_tensor = input_tensor->Get(); - auto input_buffer_requirements = - GetInputBufferRequirements(*signature_index, tensor_name); - if (!input_buffer_requirements) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - input_buffer_requirements.Error().Message()); - } - litert_buffer_requirements = input_buffer_requirements->Get(); - } else { - auto output_tensor = subgraph->Output(tensor_name); - if (!output_tensor) { - return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find output"); - } - target_litert_tensor = output_tensor->Get(); - auto output_buffer_requirements = - GetOutputBufferRequirements(*signature_index, tensor_name); - if (!output_buffer_requirements) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - output_buffer_requirements.Error().Message()); - } - litert_buffer_requirements = output_buffer_requirements->Get(); - } - - auto buffer_requirements = - TensorBufferRequirements(litert_buffer_requirements, /*owned=*/false); - auto target_tensor = Tensor(target_litert_tensor); - auto tensor_type = target_tensor.RankedTensorType(); - if (!tensor_type) { - return tensor_type.Error(); - } - return CreateBufferImpl(buffer_requirements, *tensor_type); + LITERT_ASSIGN_OR_RETURN(size_t signature_index, + model_.GetSignatureIndex(signature_name)); + LITERT_ASSIGN_OR_RETURN(Signature signature, + model_.GetSignature(signature_index)); + + LITERT_ASSIGN_OR_RETURN(Subgraph subgraph, model_.Subgraph(signature.Key())); + + Expected tensor_expected = + is_input ? subgraph.Input(tensor_name) : subgraph.Output(tensor_name); + Expected buffer_requirements_expected = + is_input ? GetInputBufferRequirements(signature_index, tensor_name) + : GetOutputBufferRequirements(signature_index, tensor_name); + + LITERT_ASSIGN_OR_RETURN(const Tensor& tensor, tensor_expected); + LITERT_ASSIGN_OR_RETURN(const TensorBufferRequirements& buffer_requirements, + buffer_requirements_expected); + LITERT_ASSIGN_OR_RETURN(const RankedTensorType& tensor_type, + tensor.RankedTensorType()); + + return CreateBufferImpl(buffer_requirements, tensor_type); } Expected> CompiledModel::CreateInputOutputBuffers( size_t signature_index, bool is_input) const { - auto signature = model_.GetSignature(signature_index); - if (!signature) { - return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find signature"); - } - auto subgraph = model_.Subgraph(signature->Key()); - if (!subgraph) { - return Unexpected(kLiteRtStatusErrorNotFound, "Failed to get subgraph"); - } + LITERT_ASSIGN_OR_RETURN(const Signature& signature, + model_.GetSignature(signature_index)); + LITERT_ASSIGN_OR_RETURN(const Subgraph subgraph, + model_.Subgraph(signature.Key())); std::vector tensor_buffers; std::vector tensor_names; - if (is_input) { - tensor_names = signature->InputNames(); - } else { - tensor_names = signature->OutputNames(); - } + + tensor_names = is_input ? signature.InputNames() : signature.OutputNames(); tensor_buffers.reserve(tensor_names.size()); for (int i = 0; i < tensor_names.size(); ++i) { - LiteRtTensor target_litert_tensor; - LiteRtTensorBufferRequirements litert_buffer_requirements; - if (is_input) { - auto input_buffer_requirements = - GetInputBufferRequirements(signature_index, i); - if (!input_buffer_requirements) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - input_buffer_requirements.Error().Message()); - } - litert_buffer_requirements = input_buffer_requirements->Get(); - auto input_tensor = subgraph->Input(tensor_names[i]); - if (!input_tensor) { - return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find input"); - } - target_litert_tensor = input_tensor->Get(); - } else { - auto output_buffer_requirements = - GetOutputBufferRequirements(signature_index, i); - if (!output_buffer_requirements) { - return Unexpected(kLiteRtStatusErrorRuntimeFailure, - output_buffer_requirements.Error().Message()); - } - litert_buffer_requirements = output_buffer_requirements->Get(); - auto output_tensor = subgraph->Output(tensor_names[i]); - if (!output_tensor) { - return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find output"); - } - target_litert_tensor = output_tensor->Get(); - } - - auto buffer_requirements = - TensorBufferRequirements(litert_buffer_requirements, /*owned=*/false); - auto target_tensor = Tensor(target_litert_tensor); - auto tensor_type = target_tensor.RankedTensorType(); - if (!tensor_type) { - return tensor_type.Error(); - } - auto tensor_buffer = CreateBufferImpl(buffer_requirements, *tensor_type); - if (!tensor_buffer) { - return tensor_buffer.Error(); - } - tensor_buffers.push_back(std::move(*tensor_buffer)); + LITERT_ASSIGN_OR_RETURN( + TensorBuffer tensor_buffer, + CreateInputOutputBuffer(signature.Key(), tensor_names[i], is_input)); + tensor_buffers.push_back(std::move(tensor_buffer)); } return tensor_buffers;