Skip to content

Commit

Permalink
Refactor CC Compiled Model buffer requirements functions.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 729675545
  • Loading branch information
tensorflower-gardener committed Feb 21, 2025
1 parent 3bcb15f commit 3666a04
Showing 1 changed file with 16 additions and 29 deletions.
45 changes: 16 additions & 29 deletions tensorflow/lite/experimental/litert/cc/litert_compiled_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
#include "tensorflow/lite/experimental/litert/c/litert_compiled_model.h"
#include "tensorflow/lite/experimental/litert/c/litert_compiled_model_options.h"
#include "tensorflow/lite/experimental/litert/c/litert_environment.h"
#include "tensorflow/lite/experimental/litert/c/litert_model.h"
#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h"
#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h"
#include "tensorflow/lite/experimental/litert/cc/litert_environment.h"
#include "tensorflow/lite/experimental/litert/cc/litert_expected.h"
Expand Down Expand Up @@ -131,19 +133,16 @@ class CompiledModel
Options&& compilation_options) {
LiteRtModel litert_model = model.Get();
LiteRtCompiledModel compiled_model;
if (auto status = LiteRtCreateCompiledModel(env.Get(), litert_model,
compilation_options.release(),
&compiled_model);
status != kLiteRtStatusOk) {
return Unexpected(status, "Failed to create compiled model");
}
LITERT_RETURN_IF_ERROR(LiteRtCreateCompiledModel(
env.Get(), litert_model, compilation_options.release(),
&compiled_model));
return CompiledModel(litert_model, compiled_model);
}

static Expected<CompiledModel> Create(
litert::Environment& env, litert::Model& model,
LiteRtHwAccelerators hardware_accelerator = kLiteRtHwAcceleratorCpu) {
LITERT_ASSIGN_OR_RETURN(auto options, Options::Create());
LITERT_ASSIGN_OR_RETURN(Options options, Options::Create());
options.SetHardwareAccelerators(hardware_accelerator);
return Create(env, model, std::move(options));
}
Expand All @@ -154,23 +153,17 @@ class CompiledModel
Expected<TensorBufferRequirements> GetInputBufferRequirements(
size_t signature_index, size_t input_index) const {
LiteRtTensorBufferRequirements buffer_requirements;
if (auto status = LiteRtGetCompiledModelInputBufferRequirements(
Get(), signature_index, input_index, &buffer_requirements);
status != kLiteRtStatusOk) {
return Unexpected(status, "Failed to get input buffer requirements");
}
LITERT_RETURN_IF_ERROR(LiteRtGetCompiledModelInputBufferRequirements(
Get(), signature_index, input_index, &buffer_requirements));
return TensorBufferRequirements(buffer_requirements, /*owned=*/false);
}

// The same as above except this function takes input tensor name.
Expected<TensorBufferRequirements> GetInputBufferRequirements(
size_t signature_index, absl::string_view input_name) const {
auto signature = model_.GetSignature(signature_index);
auto input_index = FindInputIndex(signature_index, input_name);
if (!input_index) {
return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find input");
}
return GetInputBufferRequirements(signature_index, *input_index);
LITERT_ASSIGN_OR_RETURN(size_t input_index,
FindInputIndex(signature_index, input_name));
return GetInputBufferRequirements(signature_index, input_index);
}

// Returns the buffer requirements for the given output tensor. The returned
Expand All @@ -179,23 +172,17 @@ class CompiledModel
Expected<TensorBufferRequirements> GetOutputBufferRequirements(
size_t signature_index, size_t output_index) const {
LiteRtTensorBufferRequirements buffer_requirements;
if (auto status = LiteRtGetCompiledModelOutputBufferRequirements(
Get(), signature_index, output_index, &buffer_requirements);
status != kLiteRtStatusOk) {
return Unexpected(status, "Failed to get output buffer requirements");
}
LITERT_RETURN_IF_ERROR(LiteRtGetCompiledModelOutputBufferRequirements(
Get(), signature_index, output_index, &buffer_requirements));
return TensorBufferRequirements(buffer_requirements, /*owned=*/false);
}

// The same as above except this function takes output tensor name.
Expected<TensorBufferRequirements> GetOutputBufferRequirements(
size_t signature_index, absl::string_view output_name) const {
auto signature = model_.GetSignature(signature_index);
auto output_index = FindOutputIndex(signature_index, output_name);
if (!output_index) {
return Unexpected(kLiteRtStatusErrorNotFound, "Failed to find output");
}
return GetOutputBufferRequirements(signature_index, *output_index);
LITERT_ASSIGN_OR_RETURN(size_t output_index,
FindOutputIndex(signature_index, output_name));
return GetOutputBufferRequirements(signature_index, output_index);
}

// Creates an input tensor buffer for the given signature and input name.
Expand Down

0 comments on commit 3666a04

Please sign in to comment.