diff --git a/all_models/bert/ensemble/1/.tmp b/all_models/bert/ensemble/1/.tmp new file mode 100644 index 00000000..e69de29b diff --git a/all_models/bert/ensemble/config.pbtxt b/all_models/bert/ensemble/config.pbtxt new file mode 100755 index 00000000..9cc0e0bc --- /dev/null +++ b/all_models/bert/ensemble/config.pbtxt @@ -0,0 +1,115 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +name: "ensemble" +platform: "ensemble" +max_batch_size: 200 +input [ + { + name: "text_input" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "bad_words" + data_type: TYPE_STRING + dims: [ -1 ] + optional: true + }, + { + name: "stop_words" + data_type: TYPE_STRING + dims: [ -1 ] + optional: true + } +] +output [ + { + name: "out_logits" + data_type: TYPE_FP32 + dims: [ -1] + } +] +ensemble_scheduling { + step [ + { + model_name: "preprocessing" + model_version: -1 + input_map { + key: "QUERY" + value: "text_input" + } + input_map { + key: "BAD_WORDS_DICT" + value: "bad_words" + } + input_map { + key: "STOP_WORDS_DICT" + value: "stop_words" + } + output_map { + key: "REQUEST_INPUT_LEN" + value: "_REQUEST_INPUT_LEN" + } + output_map { + key: "INPUT_ID" + value: "_INPUT_ID" + } + output_map { + key: "STOP_WORDS_IDS" + value: "_STOP_WORDS_IDS" + } + output_map { + key: "BAD_WORDS_IDS" + value: "_BAD_WORDS_IDS" + } + }, + { + model_name: "tensorrt_llm" + model_version: -1 + input_map { + key: "input_ids" + value: "_INPUT_ID" + } + input_map { + key: "input_lengths" + value: "_REQUEST_INPUT_LEN" + } + input_map { + key: "stop_words_list" + value: "_STOP_WORDS_IDS" + } + input_map { + key: "bad_words_list" + value: "_BAD_WORDS_IDS" + } + output_map { + key: "logits" + value: "out_logits" + } + } + ] +} diff --git a/all_models/bert/preprocessing/1/model.py b/all_models/bert/preprocessing/1/model.py new file mode 100644 index 00000000..343adbc4 --- /dev/null +++ b/all_models/bert/preprocessing/1/model.py @@ -0,0 +1,256 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import json
+from typing import List
+
+import numpy as np
+import triton_python_backend_utils as pb_utils
+from transformers import AutoTokenizer
+
+
+class TritonPythonModel:
+    """Your Python model must use the same class name. Every Python model
+    that is created must have "TritonPythonModel" as the class name.
+    """
+
+    def initialize(self, args):
+        """`initialize` is called only once when the model is being loaded.
+        Implementing `initialize` function is optional. This function allows
+        the model to initialize any state associated with this model.
+        Parameters
+        ----------
+        args : dict
+          Both keys and values are strings. The dictionary keys and values are:
+          * model_config: A JSON string containing the model configuration
+          * model_instance_kind: A string containing model instance kind
+          * model_instance_device_id: A string containing model instance device ID
+          * model_repository: Model repository path
+          * model_version: Model version
+          * model_name: Model name
+        """
+        # Parse model configs
+        self.logger = pb_utils.Logger
+        self.logger.log_info("Initializing preprocessing model")
+
+        model_config = json.loads(args['model_config'])
+        tokenizer_dir = model_config['parameters']['tokenizer_dir'][
+            'string_value']
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            tokenizer_dir, trust_remote_code=True)
+
+        self.max_len = int(
+            model_config['parameters']['max_length']['string_value'])
+        # Parse model output configs and convert Triton types to numpy types
+        output_names = [
+            "INPUT_ID", "REQUEST_INPUT_LEN", "BAD_WORDS_IDS", "STOP_WORDS_IDS"
+        ]
+
+        for output_name in output_names:
+            setattr(
+                self,
+                output_name.lower() + "_dtype",
+                pb_utils.triton_string_to_numpy(
+                    pb_utils.get_output_config_by_name(
+                        model_config, output_name)['data_type']))
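+    # After `initialize`, each declared output has a matching numpy dtype
+    # attribute (e.g. self.input_id_dtype is np.int32 when INPUT_ID is
+    # declared TYPE_INT32 in config.pbtxt); `execute` casts its output
+    # tensors through these attributes.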
+    def execute(self, requests):
+        """`execute` must be implemented in every Python model. `execute`
+        function receives a list of pb_utils.InferenceRequest as the only
+        argument. This function is called when an inference is requested
+        for this model. Depending on the batching configuration (e.g. Dynamic
+        Batching) used, `requests` may contain multiple requests. Every
+        Python model must create one pb_utils.InferenceResponse for every
+        pb_utils.InferenceRequest in `requests`. If there is an error, you can
+        set the error argument when creating a pb_utils.InferenceResponse.
+        Parameters
+        ----------
+        requests : list
+          A list of pb_utils.InferenceRequest
+        Returns
+        -------
+        list
+          A list of pb_utils.InferenceResponse. The length of this list must
+          be the same as `requests`
+        """
+
+        responses = []
+
+        # Every Python backend must iterate over every one of the requests
+        # and create a pb_utils.InferenceResponse for each of them.
+        for request in requests:
+            # Get input tensors
+            query = pb_utils.get_input_tensor_by_name(request,
+                                                      'QUERY').as_numpy()
+
+            bad_words_dict = pb_utils.get_input_tensor_by_name(
+                request, 'BAD_WORDS_DICT')
+            if bad_words_dict is not None:
+                bad_words_dict = bad_words_dict.as_numpy()
+
+            stop_words_dict = pb_utils.get_input_tensor_by_name(
+                request, 'STOP_WORDS_DICT')
+            if stop_words_dict is not None:
+                stop_words_dict = stop_words_dict.as_numpy()
+
+            # Preprocess the input data.
+            input_id, request_input_len = self._create_request(query)
+            bad_words = self._to_word_list_format(bad_words_dict)
+            stop_words = self._to_word_list_format(stop_words_dict)
+
+            # Create output tensors. You need pb_utils.Tensor
+            # objects to create pb_utils.InferenceResponse.
+            input_id_tensor = pb_utils.Tensor(
+                'INPUT_ID', input_id.astype(self.input_id_dtype))
+            request_input_len_tensor = pb_utils.Tensor(
+                'REQUEST_INPUT_LEN',
+                request_input_len.astype(self.request_input_len_dtype))
+
+            bad_words_ids_tensor = pb_utils.Tensor('BAD_WORDS_IDS', bad_words)
+            stop_words_ids_tensor = pb_utils.Tensor('STOP_WORDS_IDS',
+                                                    stop_words)
+
+            inference_response = pb_utils.InferenceResponse(output_tensors=[
+                input_id_tensor, request_input_len_tensor,
+                bad_words_ids_tensor, stop_words_ids_tensor,
+            ])
+            responses.append(inference_response)
+
+        # You should return a list of pb_utils.InferenceResponse. Length
+        # of this list must match the length of `requests` list.
+        return responses
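+    # Layout sketch for the word-list tensors built by `_to_word_list_format`
+    # below (the token ids here are illustrative; real values depend on the
+    # tokenizer): stop words ["stop", "end now"] tokenizing to [736] and
+    # [477, 541] produce
+    #     [[[736, 477, 541],     # concatenated token ids
+    #       [  1,   3,  -1]]]    # cumulative end offsets, padded with -1
+    # i.e. shape (batch, 2, max_total_ids), matching dims [ 2, -1 ] in
+    # config.pbtxt.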
+ """ + print('Cleaning up...') + + def _create_request(self, query): + """ + query : batch string (2D numpy array) + """ + + input_ids_with_padding = self.tokenizer( + [query[0][0].decode('utf-8')], padding='max_length', max_length=self.max_len) + input_ids_without_padding = self.tokenizer( + [query[0][0].decode('utf-8')]) + + input_ids = np.array(input_ids_with_padding['input_ids']).astype(int) + input_lengths = [len(x) for x in input_ids_without_padding['input_ids']] + input_lengths = np.array([[x] for x in input_lengths]).astype(int) + # self.logger.log(f'input_lengths.shape: {input_lengths.shape}, input_ids: {input_ids}, input_lengths: {input_lengths}', self.logger.INFO) + + return input_ids, input_lengths + + def _to_word_list_format(self, word_lists: List[List[str | bytes]]): + ''' + word_lists format: + len(word_lists) == batch_size + word_lists[i] means the words associated to batch item i. A "word" may actually be any string. Like "lorem" or "lorem ipsum". + ''' + assert self.tokenizer != None, "need to set tokenizer" + + if word_lists is None: + # Return an empty array of shape (1,2,0) + return np.empty([1, 2, 0], dtype="int32") + + flat_ids = [] + offsets = [] + for word_list in word_lists: + item_flat_ids = [] + item_offsets = [] + + for word in word_list: + if isinstance(word, bytes): + word = word.decode() + + ids = self.tokenizer.encode(word, add_special_tokens=False) + if len(ids) == 0: + continue + + item_flat_ids += ids + item_offsets.append(len(ids)) + + flat_ids.append(np.array(item_flat_ids)) + offsets.append(np.cumsum(np.array(item_offsets))) + + pad_to = max(1, max(len(ids) for ids in flat_ids)) + + for i, (ids, offs) in enumerate(zip(flat_ids, offsets)): + flat_ids[i] = np.pad(ids, (0, pad_to - len(ids)), + constant_values=0) + offsets[i] = np.pad(offs, (0, pad_to - len(offs)), + constant_values=-1) + + return np.array([flat_ids, offsets], dtype="int32").transpose( + (1, 0, 2)) + + def _get_embedding_bias(self, embedding_bias_words, embedding_bias_weights, + bias_dtype): + + assert self.tokenizer != None, "need to set tokenizer" + + if embedding_bias_words is None or embedding_bias_weights is None: + return np.empty([1, 0], dtype=self.embedding_bias_weights_dtype) + + batch_embedding_bias = [] + for words, weights in zip(embedding_bias_words, + embedding_bias_weights): + + vocab_size = self.tokenizer.vocab_size + embedding_bias = [0.] * vocab_size + + assert len(words) == len( + weights + ), "Embedding bias words must have same dimension as embedding bias weights" + + for word, weight in zip(words, weights): + if isinstance(word, bytes): + word = word.decode() + ids = self.tokenizer.encode(word) + + if len(ids) == 0: + continue + + for id in ids: + embedding_bias[id] += weight + + batch_embedding_bias.append(np.array(embedding_bias)) + + return np.array(batch_embedding_bias, dtype=bias_dtype) diff --git a/all_models/bert/preprocessing/config.pbtxt b/all_models/bert/preprocessing/config.pbtxt new file mode 100644 index 00000000..79783f95 --- /dev/null +++ b/all_models/bert/preprocessing/config.pbtxt @@ -0,0 +1,105 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. 
diff --git a/all_models/bert/preprocessing/config.pbtxt b/all_models/bert/preprocessing/config.pbtxt
new file mode 100644
index 00000000..79783f95
--- /dev/null
+++ b/all_models/bert/preprocessing/config.pbtxt
@@ -0,0 +1,105 @@
+# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+name: "preprocessing"
+backend: "python"
+max_batch_size: 200
+input [
+    {
+        name: "QUERY"
+        data_type: TYPE_STRING
+        dims: [ -1 ]
+    },
+    {
+        name: "BAD_WORDS_DICT"
+        data_type: TYPE_STRING
+        dims: [ -1 ]
+        optional: true
+    },
+    {
+        name: "STOP_WORDS_DICT"
+        data_type: TYPE_STRING
+        dims: [ -1 ]
+        optional: true
+    }
+]
+output [
+    {
+        name: "INPUT_ID"
+        data_type: TYPE_INT32
+        dims: [ -1 ]
+    },
+    {
+        name: "REQUEST_INPUT_LEN"
+        data_type: TYPE_INT32
+        dims: [ 1 ]
+    },
+    {
+        name: "BAD_WORDS_IDS"
+        data_type: TYPE_INT32
+        dims: [ 2, -1 ]
+    },
+    {
+        name: "STOP_WORDS_IDS"
+        data_type: TYPE_INT32
+        dims: [ 2, -1 ]
+    }
+]
+
+parameters {
+  key: "tokenizer_dir"
+  value: {
+    string_value: "${tokenizer_dir}"
+  }
+}
+
+parameters {
+  key: "tokenizer_type"
+  value: {
+    string_value: "auto"
+  }
+}
+
+parameters {
+  key: "add_special_tokens"
+  value: {
+    string_value: "False"
+  }
+}
+
+parameters {
+  key: "max_length"
+  value: {
+    string_value: "128"
+  }
+}
+
+instance_group [
+    {
+        count: 1
+        kind: KIND_CPU
+    }
+]
diff --git a/all_models/bert/tensorrt_llm/1/model.py b/all_models/bert/tensorrt_llm/1/model.py
new file mode 100644
index 00000000..d3036a6c
--- /dev/null
+++ b/all_models/bert/tensorrt_llm/1/model.py
@@ -0,0 +1,220 @@
+import json
+import os
+
+import torch
+import triton_python_backend_utils as pb_utils
+from torch import from_numpy
+
+import tensorrt_llm
+from tensorrt_llm.runtime import Session, TensorInfo
+from tensorrt_llm.functional import str_dtype_to_trt
+import tensorrt as trt
+
+
+def trt_dtype_to_torch(dtype):
+    """Map a TensorRT dtype to the matching torch dtype."""
+    if dtype == trt.float16:
+        return torch.float16
+    elif dtype == trt.float32:
+        return torch.float32
+    elif dtype == trt.int32:
+        return torch.int32
+    else:
+        raise TypeError("%s is not supported" % dtype)
+
+
+def mpi_comm():
+    from mpi4py import MPI
+    return MPI.COMM_WORLD
+
+
+def mpi_rank():
+    return mpi_comm().Get_rank()
+
+
+def get_engine_name(model, dtype, tp_size, rank):
+    return '{}_{}_tp{}_rank{}.engine'.format(model, dtype, tp_size, rank)
+
+
+def get_input_tensor_by_name(request, name):
+    tensor = pb_utils.get_input_tensor_by_name(request, name)
+    if tensor is not None:
+        # Triton tensor -> numpy tensor -> PyTorch tensor
+        return from_numpy(tensor.as_numpy())
+    else:
+        return tensor
+
+
+def get_input_scalar_by_name(request, name):
+    tensor = pb_utils.get_input_tensor_by_name(request, name)
+    if tensor is not None:
+        # Triton tensor -> numpy tensor -> first scalar
+        tensor = tensor.as_numpy()
+        return tensor.reshape((tensor.size, ))[0]
+    else:
+        return tensor
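+
+# With tensor parallelism (world_size > 1), only MPI rank 0 is driven by
+# Triton: it receives requests in `execute` and broadcasts the input tensors
+# to the other ranks, which spin inside `initialize` calling `execute([None])`
+# so that every rank participates in the same engine run.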
+class TritonPythonModel:
+    """Your Python model must use the same class name. Every Python model
+    that is created must have "TritonPythonModel" as the class name.
+    """
+
+    def initialize(self, args):
+        """`initialize` is called only once when the model is being loaded.
+        Implementing `initialize` function is optional. This function allows
+        the model to initialize any state associated with this model.
+
+        Parameters
+        ----------
+        args : dict
+          Both keys and values are strings. The dictionary keys and values are:
+          * model_config: A JSON string containing the model configuration
+          * model_instance_kind: A string containing model instance kind
+          * model_instance_device_id: A string containing model instance device ID
+          * model_repository: Model repository path
+          * model_version: Model version
+          * model_name: Model name
+        """
+        self.logger = pb_utils.Logger
+        self.logger.log_info("Initializing tensorrt_llm model")
+
+        model_config = json.loads(args['model_config'])
+        engine_dir = model_config['parameters']['engine_dir']['string_value']
+
+        config_path = os.path.join(engine_dir, 'config.json')
+        with open(config_path, 'r') as f:
+            config = json.load(f)
+        dtype = config['builder_config']['precision']
+        world_size = config['builder_config']['tensor_parallel']
+        assert world_size == tensorrt_llm.mpi_world_size(), \
+            f'Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})'
+
+        model_name = config['builder_config']['name']
+        runtime_rank = tensorrt_llm.mpi_rank() if world_size > 1 else 0
+
+        runtime_mapping = tensorrt_llm.Mapping(world_size,
+                                               runtime_rank,
+                                               tp_size=world_size)
+        serialize_path = get_engine_name(model_name, dtype, world_size,
+                                         runtime_rank)
+        serialize_path = os.path.join(engine_dir, serialize_path)
+
+        self.stream = torch.cuda.current_stream().cuda_stream
+        print(f'Loading engine from {serialize_path}')
+        with open(serialize_path, 'rb') as f:
+            engine_buffer = f.read()
+        print('Creating session from engine')
+        self.session = Session.from_serialized_engine(engine_buffer)
+
+        self.comm = mpi_comm()
+        self.rank = mpi_rank()
+        torch.cuda.set_device(self.rank % runtime_mapping.gpus_per_node)
+
+        if self.rank != 0:
+            # Non-zero ranks never receive Triton requests directly; they
+            # loop here forever, joining the broadcast in `execute`.
+            while True:
+                self.execute([None])
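+    # For reference, `get_engine_name` resolves names such as
+    # get_engine_name('bert', 'float16', 2, 0) -> 'bert_float16_tp2_rank0.engine'.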
+    def execute(self, requests):
+        """`execute` must be implemented in every Python model. `execute`
+        function receives a list of pb_utils.InferenceRequest as the only
+        argument. This function is called when an inference is requested
+        for this model.
+
+        Parameters
+        ----------
+        requests : list
+          A list of pb_utils.InferenceRequest
+
+        Returns
+        -------
+        list
+          A list of pb_utils.InferenceResponse. The length of this list must
+          be the same as `requests`
+        """
+        responses = []
+
+        # Every Python backend must iterate through the list of requests and
+        # create an instance of pb_utils.InferenceResponse for each of them.
+        # Avoid storing any of the input Tensors in class attributes, as they
+        # will be overridden in subsequent inference requests. Make a copy of
+        # the underlying NumPy array and store it if that is required.
+        for request in requests:
+            inputs = {}
+            if self.rank == 0:
+                inputs['input_ids'] = get_input_tensor_by_name(
+                    request, 'input_ids')
+                inputs['input_lengths'] = get_input_tensor_by_name(
+                    request, 'input_lengths')
+
+            # Broadcast the inputs from rank 0 to the other ranks
+            inputs = self.comm.bcast(inputs, root=0)
+
+            input_ids = inputs['input_ids'].cuda()
+
+            batch_dim = input_ids.shape[0]
+            input_len = input_ids.shape[1]
+            input_lengths = inputs['input_lengths'].cuda()
+
+            inputs = {
+                'input_ids': input_ids,
+                'input_lengths': input_lengths,
+            }
+
+            output_info = self.session.infer_shapes([
+                TensorInfo('input_ids', str_dtype_to_trt('int32'),
+                           (batch_dim, input_len)),
+                TensorInfo('input_lengths', str_dtype_to_trt('int32'),
+                           (batch_dim, ))
+            ])
+
+            outputs = {
+                t.name: torch.empty(tuple(t.shape),
+                                    dtype=trt_dtype_to_torch(t.dtype),
+                                    device='cuda')
+                for t in output_info
+            }
+            output_name = 'logits'
+            assert output_name in outputs, \
+                f'{output_name} not found in outputs, check if build.py set the name correctly'
+            ok = self.session.run(inputs, outputs, self.stream)
+            assert ok, "Runtime execution failed"
+
+            logits = outputs[output_name]
+            logits = logits.to(dtype=torch.float32)
+
+            if self.rank == 0:
+                # Create output tensors. You need pb_utils.Tensor
+                # objects to create pb_utils.InferenceResponse.
+                torch.cuda.synchronize()
+                logits = [
+                    pb_utils.Tensor("logits",
+                                    logits.cpu().numpy())
+                ]
+
+                # Create InferenceResponse. You can set an error here in case
+                # there was a problem with handling this inference request.
+                # For example:
+                #
+                # pb_utils.InferenceResponse(
+                #     output_tensors=...,
+                #     error=pb_utils.TritonError("An error occurred"))
+
+                inference_response = pb_utils.InferenceResponse(logits)
+            else:
+                inference_response = pb_utils.InferenceResponse([])
+            responses.append(inference_response)
+
+        # You must return a list of pb_utils.InferenceResponse. Length
+        # of this list must match the length of `requests` list.
+        return responses
+
+    def finalize(self):
+        """`finalize` is called only once when the model is being unloaded.
+        Implementing `finalize` function is optional. This function allows
+        the model to perform any necessary clean ups before exit.
+        """
+        return
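+
+# A minimal client sketch for querying the "ensemble" model (an illustration,
+# not part of the backend; it assumes tritonclient is installed and Triton is
+# serving HTTP on the default port):
+#
+#     import numpy as np
+#     import tritonclient.http as httpclient
+#
+#     client = httpclient.InferenceServerClient("localhost:8000")
+#     text = np.array([["some input text"]], dtype=object)
+#     inp = httpclient.InferInput("text_input", list(text.shape), "BYTES")
+#     inp.set_data_from_numpy(text)
+#     logits = client.infer("ensemble", [inp]).as_numpy("out_logits")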
diff --git a/all_models/bert/tensorrt_llm/config.pbtxt b/all_models/bert/tensorrt_llm/config.pbtxt
new file mode 100644
index 00000000..819613c5
--- /dev/null
+++ b/all_models/bert/tensorrt_llm/config.pbtxt
@@ -0,0 +1,88 @@
+# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+name: "tensorrt_llm"
+backend: "python"
+max_batch_size: 200
+
+model_transaction_policy {
+  decoupled: false
+}
+
+dynamic_batching {
+    preferred_batch_size: [ 200 ]
+    max_queue_delay_microseconds: 2000
+}
+
+input [
+    {
+        name: "input_ids"
+        data_type: TYPE_INT32
+        dims: [ -1 ]
+        allow_ragged_batch: true
+    },
+    {
+        name: "input_lengths"
+        data_type: TYPE_INT32
+        dims: [ 1 ]
+        reshape: { shape: [ ] }
+    },
+    {
+        name: "stop_words_list"
+        data_type: TYPE_INT32
+        dims: [ 2, -1 ]
+        optional: true
+        allow_ragged_batch: true
+    },
+    {
+        name: "bad_words_list"
+        data_type: TYPE_INT32
+        dims: [ 2, -1 ]
+        optional: true
+        allow_ragged_batch: true
+    }
+]
+output [
+    {
+        name: "logits"
+        data_type: TYPE_FP32
+        dims: [ -1 ]
+    }
+]
+instance_group [
+    {
+        count: 1
+        kind: KIND_CPU
+    }
+]
+parameters {
+  key: "engine_dir"
+  value: {
+    string_value: "${engine_dir}"
+  }
+}
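+
+# Note: ${engine_dir} (like ${tokenizer_dir} in the preprocessing config) is a
+# template placeholder that is expected to be substituted before deployment,
+# e.g. with a fill-template script; Triton does not expand it at load time.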