Tritonserver fails to start with the TensorRT-LLM backend in lookahead_decoding mode - Assertion failure in lookaheadDecodingLayer.cpp #710

shaylapid commented Feb 18, 2025

System Info

  • CPU architecture: x86_64
  • GPU: NVIDIA H100 80GB
  • TensorRT-LLM backend tag: v0.17.0
  • Container used: nvcr.io/nvidia/tritonserver:25.01-trtllm-python-py3
  • OS: Debian GNU/Linux 11 (bullseye)

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Build the model:

Start the container:

docker run --rm -it --net host --shm-size=2g \
    --ulimit memlock=-1 --ulimit stack=67108864 --gpus all \
    -v </path/to/git/tensorrtllm_backend>:/tensorrtllm_backend \
    -v </path/to/engines>:/model/engine \
    -v <path/to/hf-checkpoint>:/model/src \
    nvcr.io/nvidia/tritonserver:25.01-trtllm-python-py3

Quantize the model:

cd /tensorrtllm_backend/tensorrt_llm/examples/quantization;
python quantize.py \
    --model_dir /model/src  \
    --qformat fp8 \
    --kv_cache_dtype fp8 \
    --output_dir /model/build

Build:

trtllm-build \
    --checkpoint_dir /model/build \
    --output_dir /model/engine \
    --gpt_attention_plugin auto \
    --gemm_plugin fp8 \
    --gemm_swiglu_plugin fp8 \
    --low_latency_gemm_swiglu_plugin fp8 \
    --remove_input_padding enable \
    --context_fmha enable \
    --max_beam_width 1 \
    --max_num_tokens 1000 \
    --max_seq_len 250 \
    --max_input_len 200 \
    --max_batch_size 4 \
    --use_fused_mlp enable \
    --use_fp8_context_fmha enable \
    --use_paged_context_fmha enable \
    --speculative_decoding_mode lookahead_decoding \
    --max_draft_len 39
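
For reference, the engine's recorded build settings can be inspected after this step. The sketch below is only an assumption: it assumes trtllm-build writes a config.json into the --output_dir with a "build_config" section, and the exact key names may differ between TensorRT-LLM releases.

# Sketch: print the draft/speculative-decoding related build settings recorded
# by trtllm-build. Assumes /model/engine/config.json with a "build_config"
# section; exact key names may differ between TensorRT-LLM releases.
import json

with open("/model/engine/config.json") as f:
    cfg = json.load(f)

build_cfg = cfg.get("build_config", cfg)  # fall back to the whole document
for key, value in build_cfg.items():
    if any(s in key for s in ("draft", "speculative", "decoding")):
        print(f"{key}: {value}")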

Adapt the model repo:

Add the following to config.pbtxt:

parameters: {
  key: "decoding_mode"
  value: {
    string_value: "lookahead"
  }
}

Run with Tritonserver:

Start the container:

docker run --rm -it --net host --shm-size=2g \
    --ulimit memlock=-1 --ulimit stack=67108864 --gpus all \
    -v <path/to/model>:/models \
    nvcr.io/nvidia/tritonserver:25.01-trtllm-python-py3

Start Tritonserver:

tritonserver --model-repository=/models

Expected behavior

Tritonserver should start successfully, and model inference should be available.
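
A quick way to check that is to poll Triton's readiness endpoints once the server is up. The sketch below assumes the default HTTP port 8000 (the container runs with --net host) and uses the model name tensorrt_llm_2beam that appears in the log below; adjust both as needed.

# Sketch of a readiness check against Triton's HTTP API.
# Assumes the default HTTP port 8000 and the model name "tensorrt_llm_2beam"
# from the model repository.
import urllib.request

def is_ready(url: str) -> bool:
    try:
        with urllib.request.urlopen(url, timeout=5) as resp:
            return resp.status == 200
    except Exception:
        return False

print("server ready:", is_ready("http://localhost:8000/v2/health/ready"))
print("model ready: ", is_ready("http://localhost:8000/v2/models/tensorrt_llm_2beam/ready"))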

Actual behavior

Tritonserver fails to start with the following assertion error:

E0218 20:57:33.147956 130 model_lifecycle.cc:654] "failed to load 'tensorrt_llm_2beam' version 1: Internal: unexpected error when creating modelInstanceState: [TensorRT-LLM][ERROR] Assertion failed: 16 != 40 (/workspace/tensorrt_llm/cpp/tensorrt_llm/layers/lookaheadDecodingLayer.cpp:56)\n1 0x7ff34f6bdff8 tensorrt_llm::common::throwRuntimeError(char const*, int, std::__cxx11::basic_string<char, std::char_traits, std::allocator > const&) + 95\n2 0x7ff34f9d890c tensorrt_llm::layers::LookaheadDecodingLayer<__half>::CpuAlgorithmResources::CpuAlgorithmResources(tensorrt_llm::layers::DecoderDomain const&) + 4396\n3 0x7ff34f9d90c1 tensorrt_llm::layers::LookaheadDecodingLayer<__half>::LookaheadDecodingLayer(tensorrt_llm::layers::DecoderDomain const&, std::shared_ptr<tensorrt_llm::runtime::BufferManager>) + 241\n4 0x7ff34f97e862 tensorrt_llm::layers::DecodingLayer<__half>::DecodingLayer(tensorrt_llm::executor::DecodingMode const&, tensorrt_llm::layers::DecoderDomain const&, std::shared_ptr<tensorrt_llm::runtime::BufferManager>) + 978\n5 0x7ff34f994c88 tensorrt_llm::layers::DynamicDecodeLayer<__half>::initializeLayers() + 872\n6 0x7ff34f995bf9 tensorrt_llm::layers::DynamicDecodeLayer<__half>::initialize() + 1321\n7 0x7ff34f995dfa tensorrt_llm::layers::DynamicDecodeLayer<__half>::DynamicDecodeLayer(tensorrt_llm::executor::DecodingMode const&, tensorrt_llm::layers::DecoderDomain const&, std::shared_ptr<tensorrt_llm::runtime::BufferManager>) + 202\n8 0x7ff34fa8da0b tensorrt_llm::runtime::GptDecoder<__half>::GptDecoder(tensorrt_llm::executor::DecodingMode const&, unsigned long, unsigned long, unsigned long, unsigned long, unsigned long, std::shared_ptr<tensorrt_llm::runtime::CudaStream> const&, std::shared_ptr<tensorrt_llm::runtime::SpeculativeDecodingModule const>) + 603\n9 0x7ff34fa9a1bc tensorrt_llm::runtime::GptDecoderBatched::setup(tensorrt_llm::executor::DecodingMode const&, int, int, int, int, int, int, nvinfer1::DataType, tensorrt_llm::runtime::ModelConfig const&) + 3372\n10 0x7ff3504e7a99 tensorrt_llm::batch_manager::TrtGptModelInflightBatching::createDecoder(std::optional<tensorrt_llm::executor::DecodingMode> const&) + 825\n11 0x7ff3504fdec0 tensorrt_llm::batch_manager::TrtGptModelInflightBatching::TrtGptModelInflightBatching(std::shared_ptrnvinfer1::ILogger, tensorrt_llm::runtime::ModelConfig const&, tensorrt_llm::runtime::WorldConfig const&, tensorrt_llm::runtime::RawEngine const&, bool, tensorrt_llm::batch_manager::TrtGptModelOptionalParams const&) + 3168\n12 0x7ff350476df9 tensorrt_llm::batch_manager::TrtGptModelFactory::create(tensorrt_llm::runtime::RawEngine const&, tensorrt_llm::runtime::ModelConfig const&, tensorrt_llm::runtime::WorldConfig const&, tensorrt_llm::batch_manager::TrtGptModelType, tensorrt_llm::batch_manager::TrtGptModelOptionalParams const&) + 489\n13 0x7ff350597369 tensorrt_llm::executor::Executor::Impl::createModel(tensorrt_llm::runtime::RawEngine const&, tensorrt_llm::runtime::ModelConfig const&, tensorrt_llm::runtime::WorldConfig const&, tensorrt_llm::executor::ExecutorConfig const&) + 185\n14 0x7ff3505979fd tensorrt_llm::executor::Executor::Impl::loadModel(std::optionalstd::filesystem::__cxx11::path const&, std::optional<std::basic_string_view<unsigned char, std::char_traits > > const&, tensorrt_llm::runtime::GptJsonConfig const&, tensorrt_llm::executor::ExecutorConfig const&, bool, std::optional<std::map<std::__cxx11::basic_string<char, std::char_traits, std::allocator >, tensorrt_llm::executor::Tensor, std::less<std::__cxx11::basic_string<char, std::char_traits, 
std::allocator > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits, std::allocator > const, tensorrt_llm::executor::Tensor> > > > const&) + 1229\n15 0x7ff350598c4a tensorrt_llm::executor::Executor::Impl::Impl(std::filesystem::__cxx11::path const&, std::optionalstd::filesystem::__cxx11::path const&, tensorrt_llm::executor::ModelType, tensorrt_llm::executor::ExecutorConfig const&) + 2474\n16 0x7ff35057e6d7 tensorrt_llm::executor::Executor::Executor(std::filesystem::__cxx11::path const&, tensorrt_llm::executor::ModelType, tensorrt_llm::executor::ExecutorConfig const&) + 87\n17 0x7ff5e803588e /opt/tritonserver/backends/tensorrtllm/libtriton_tensorrtllm.so(+0x3388e) [0x7ff5e803588e]\n18 0x7ff5e8032049 triton::backend::inflight_batcher_llm::ModelInstanceState::ModelInstanceState(triton::backend::inflight_batcher_llm::ModelState*, TRITONBACKEND_ModelInstance*) + 2185\n19 0x7ff5e8032592 triton::backend::inflight_batcher_llm::ModelInstanceState::Create(triton::backend::inflight_batcher_llm::ModelState*, TRITONBACKEND_ModelInstance*, triton::backend::inflight_batcher_llm::ModelInstanceState**) + 66\n20 0x7ff5e801f929 TRITONBACKEND_ModelInstanceInitialize + 153\n21 0x7ff5f6bd7649 /opt/tritonserver/bin/../lib/libtritonserver.so(+0x1a1649) [0x7ff5f6bd7649]\n22 0x7ff5f6bd80d2 /opt/tritonserver/bin/../lib/libtritonserver.so(+0x1a20d2) [0x7ff5f6bd80d2]\n23 0x7ff5f6bbdcf3 /opt/tritonserver/bin/../lib/libtritonserver.so(+0x187cf3) [0x7ff5f6bbdcf3]\n24 0x7ff5f6bbe0a4 /opt/tritonserver/bin/../lib/libtritonserver.so(+0x1880a4) [0x7ff5f6bbe0a4]\n25 0x7ff5f6bc768d /opt/tritonserver/bin/../lib/libtritonserver.so(+0x19168d) [0x7ff5f6bc768d]\n26 0x7ff5f6134ec3 /usr/lib/x86_64-linux-gnu/libc.so.6(+0xa1ec3) [0x7ff5f6134ec3]\n27 0x7ff5f6bb4f02 /opt/tritonserver/bin/../lib/libtritonserver.so(+0x17ef02) [0x7ff5f6bb4f02]\n28 0x7ff5f6bc2ddc /opt/tritonserver/bin/../lib/libtritonserver.so(+0x18cddc) [0x7ff5f6bc2ddc]\n29 0x7ff5f6bc6e12 /opt/tritonserver/bin/../lib/libtritonserver.so(+0x190e12) [0x7ff5f6bc6e12]\n30 0x7ff5f6cc78e1 /opt/tritonserver/bin/../lib/libtritonserver.so(+0x2918e1) [0x7ff5f6cc78e1]\n31 0x7ff5f6ccac3c /opt/tritonserver/bin/../lib/libtritonserver.so(+0x294c3c) [0x7ff5f6ccac3c]\n32 0x7ff5f6e27305 /opt/tritonserver/bin/../lib/libtritonserver.so(+0x3f1305) [0x7ff5f6e27305]\n33 0x7ff5f6391db4 /usr/lib/x86_64-linux-gnu/libstdc++.so.6(+0xecdb4) [0x7ff5f6391db4]\n34 0x7ff5f612fa94 /usr/lib/x86_64-linux-gnu/libc.so.6(+0x9ca94) [0x7ff5f612fa94]\n35 0x7ff5f61bca34 __clone + 68"
I0218 20:57:33.148431 130 model_lifecycle.cc:789] "failed to load 'tensorrt_llm_2beam'"
I0218 20:57:33.148569 130 server.cc:604]

Additional notes

Changing --max_draft_len to 15 allows Tritonserver to start, but this prevents selecting the desired max_draft_len value.
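
The 40 in the assertion presumably corresponds to the engine's max_draft_len of 39 plus one decoding token, while the 16 looks like whatever lookahead configuration the decoder falls back to by default. My understanding of the lookahead-decoding docs is that max_draft_len has to cover the (window size W, ngram size N, verification set size G) tuple roughly as in the sketch below; this formula is my reading of the documentation and may differ between releases.

# Sketch: relation between the lookahead (W, N, G) tuple and the engine's
# max_draft_len, as I understand it from the TensorRT-LLM lookahead docs
# (W = window size, N = ngram size, G = verification set size).
def lookahead_max_draft_len(w: int, n: int, g: int) -> int:
    return (0 if n == 1 else n - 2) + (w - 1 + g) * (n - 1)

# Example: a (5, 5, 5) configuration needs max_draft_len = 39, which is
# what the engine above was built with.
print(lookahead_max_draft_len(5, 5, 5))  # 39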

shaylapid added the bug label Feb 18, 2025