Skip to content

Commit

Permalink
Add signature-key based API for CC Compiled Model.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 729679764
  • Loading branch information
tensorflower-gardener committed Feb 21, 2025
1 parent 3666a04 commit 2c74f04
Showing 1 changed file with 46 additions and 0 deletions.
46 changes: 46 additions & 0 deletions tensorflow/lite/experimental/litert/cc/litert_compiled_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,14 @@ class CompiledModel
return Create(env, model, std::move(options));
}

// Get input buffer requirements for the given signature and input name.
Expected<TensorBufferRequirements> GetInputBufferRequirements(
absl::string_view signature_name, absl::string_view input_name) {
LITERT_ASSIGN_OR_RETURN(size_t signature_index,
model_.GetSignatureIndex(signature_name));
return GetInputBufferRequirements(signature_index, input_name);
}

// Returns the buffer requirements for the given n-th input tensor. The
// returned TensorBufferRequirements is used to create the input tensor
// buffer.
Expand All @@ -166,6 +174,14 @@ class CompiledModel
return GetInputBufferRequirements(signature_index, input_index);
}

// Get output buffer requirements for the given signature and output name.
Expected<TensorBufferRequirements> GetOutputBufferRequirements(
absl::string_view signature_name, absl::string_view output_name) {
LITERT_ASSIGN_OR_RETURN(size_t signature_index,
model_.GetSignatureIndex(signature_name));
return GetOutputBufferRequirements(signature_index, output_name);
}

// Returns the buffer requirements for the given output tensor. The returned
// TensorBufferRequirements is used to create the output tensor
// buffer.
Expand Down Expand Up @@ -199,6 +215,16 @@ class CompiledModel
/*is_input=*/false);
}

// A helper function to create input tensor buffers for the given signature.
// It uses BufferRequirements and RankedTensorType to create the input tensor
// buffers.
Expected<std::vector<TensorBuffer>> CreateInputBuffers(
absl::string_view signature_name) const {
LITERT_ASSIGN_OR_RETURN(size_t signature_index,
model_.GetSignatureIndex(signature_name));
return CreateInputOutputBuffers(signature_index, /*is_input=*/true);
}

// A helper function to creates the input tensor buffers for the given
// signature. It uses BufferRequirements and RankedTensorType to create the
// input tensor buffers.
Expand All @@ -207,6 +233,16 @@ class CompiledModel
return CreateInputOutputBuffers(signature_index, /*is_input=*/true);
}

// A helper function to create output tensor buffers for the given signature.
// It uses BufferRequirements and RankedTensorType to create the output tensor
// buffers.
Expected<std::vector<TensorBuffer>> CreateOutputBuffers(
absl::string_view signature_name) const {
LITERT_ASSIGN_OR_RETURN(size_t signature_index,
model_.GetSignatureIndex(signature_name));
return CreateOutputBuffers(signature_index);
}

// A helper function to creates the output tensor buffers for the given
// signature. It uses BufferRequirements and RankedTensorType to create the
// output tensor buffers.
Expand Down Expand Up @@ -236,6 +272,16 @@ class CompiledModel
return RunHelper(signature_index, input_buffers, output_buffers, async);
}

// Runs the model of the given signature key synchronously with the provided
// input/output TensorBuffers.
Expected<void> Run(absl::string_view signature_key,
const std::vector<TensorBuffer>& input_buffers,
const std::vector<TensorBuffer>& output_buffers) const {
LITERT_ASSIGN_OR_RETURN(size_t signature_index,
model_.GetSignatureIndex(signature_key));
return Run(signature_index, input_buffers, output_buffers);
}

// Runs the model of the given signature key synchronously with the provided
// input/output TensorBuffer map.
Expected<void> Run(
Expand Down

0 comments on commit 2c74f04

Please sign in to comment.