From 3666a0475a3e5cbc3ec1eec1f90c403899d99132 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 21 Feb 2025 15:18:13 -0800 Subject: [PATCH] Refactor CC Compiled Model buffer requirements functions. PiperOrigin-RevId: 729675545 --- .../litert/cc/litert_compiled_model.h | 45 +++++++------------ 1 file changed, 16 insertions(+), 29 deletions(-) diff --git a/tensorflow/lite/experimental/litert/cc/litert_compiled_model.h b/tensorflow/lite/experimental/litert/cc/litert_compiled_model.h index a22cecd2125e62..c7b0682b7c7d4c 100644 --- a/tensorflow/lite/experimental/litert/cc/litert_compiled_model.h +++ b/tensorflow/lite/experimental/litert/cc/litert_compiled_model.h @@ -27,6 +27,8 @@ #include "tensorflow/lite/experimental/litert/c/litert_compiled_model.h" #include "tensorflow/lite/experimental/litert/c/litert_compiled_model_options.h" #include "tensorflow/lite/experimental/litert/c/litert_environment.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" #include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" #include "tensorflow/lite/experimental/litert/cc/litert_environment.h" #include "tensorflow/lite/experimental/litert/cc/litert_expected.h" @@ -131,19 +133,16 @@ class CompiledModel Options&& compilation_options) { LiteRtModel litert_model = model.Get(); LiteRtCompiledModel compiled_model; - if (auto status = LiteRtCreateCompiledModel(env.Get(), litert_model, - compilation_options.release(), - &compiled_model); - status != kLiteRtStatusOk) { - return Unexpected(status, "Failed to create compiled model"); - } + LITERT_RETURN_IF_ERROR(LiteRtCreateCompiledModel( + env.Get(), litert_model, compilation_options.release(), + &compiled_model)); return CompiledModel(litert_model, compiled_model); } static Expected Create( litert::Environment& env, litert::Model& model, LiteRtHwAccelerators hardware_accelerator = kLiteRtHwAcceleratorCpu) { - LITERT_ASSIGN_OR_RETURN(auto options, Options::Create()); + LITERT_ASSIGN_OR_RETURN(Options options, Options::Create()); options.SetHardwareAccelerators(hardware_accelerator); return Create(env, model, std::move(options)); } @@ -154,23 +153,17 @@ class CompiledModel Expected GetInputBufferRequirements( size_t signature_index, size_t input_index) const { LiteRtTensorBufferRequirements buffer_requirements; - if (auto status = LiteRtGetCompiledModelInputBufferRequirements( - Get(), signature_index, input_index, &buffer_requirements); - status != kLiteRtStatusOk) { - return Unexpected(status, "Failed to get input buffer requirements"); - } + LITERT_RETURN_IF_ERROR(LiteRtGetCompiledModelInputBufferRequirements( + Get(), signature_index, input_index, &buffer_requirements)); return TensorBufferRequirements(buffer_requirements, /*owned=*/false); } // The same as above except this function takes input tensor name. Expected GetInputBufferRequirements( size_t signature_index, absl::string_view input_name) const { - auto signature = model_.GetSignature(signature_index); - auto input_index = FindInputIndex(signature_index, input_name); - if (!input_index) { - return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find input"); - } - return GetInputBufferRequirements(signature_index, *input_index); + LITERT_ASSIGN_OR_RETURN(size_t input_index, + FindInputIndex(signature_index, input_name)); + return GetInputBufferRequirements(signature_index, input_index); } // Returns the buffer requirements for the given output tensor. The returned @@ -179,23 +172,17 @@ class CompiledModel Expected GetOutputBufferRequirements( size_t signature_index, size_t output_index) const { LiteRtTensorBufferRequirements buffer_requirements; - if (auto status = LiteRtGetCompiledModelOutputBufferRequirements( - Get(), signature_index, output_index, &buffer_requirements); - status != kLiteRtStatusOk) { - return Unexpected(status, "Failed to get output buffer requirements"); - } + LITERT_RETURN_IF_ERROR(LiteRtGetCompiledModelOutputBufferRequirements( + Get(), signature_index, output_index, &buffer_requirements)); return TensorBufferRequirements(buffer_requirements, /*owned=*/false); } // The same as above except this function takes output tensor name. Expected GetOutputBufferRequirements( size_t signature_index, absl::string_view output_name) const { - auto signature = model_.GetSignature(signature_index); - auto output_index = FindOutputIndex(signature_index, output_name); - if (!output_index) { - return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find output"); - } - return GetOutputBufferRequirements(signature_index, *output_index); + LITERT_ASSIGN_OR_RETURN(size_t output_index, + FindOutputIndex(signature_index, output_name)); + return GetOutputBufferRequirements(signature_index, output_index); } // Creates an input tensor buffer for the given signature and input name.