
[Model][VLM] Add Qwen2.5-Omni model support (thinker only) #15130


Merged: 39 commits merged into vllm-project:main on Apr 19, 2025

Conversation

@fyabc (Contributor) commented Mar 19, 2025

This PR adds support for the Qwen2.5-Omni model (thinker only).

Requirements

This PR requires the corresponding transformers PR:

pip install git+https://github.com/huggingface/transformers@f742a644ca32e65758c3adb36225aef1731bd2a8

Note: you need to install transformers from source at the pinned commit above.
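
A quick sanity check that the source build is active (a sketch; the top-level export of Qwen2_5OmniProcessor is an assumption based on the transformers PR):

# Verify the transformers source install (sketch).
import transformers

print(transformers.__version__)  # a source build usually reports a ".dev0" suffix

# The Qwen2.5-Omni processor only exists on that branch; this import fails
# on a stock release that predates the transformers PR.
from transformers import Qwen2_5OmniProcessor  # noqa: F401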

Example Usage

# Audio + image + video
python examples/offline_inference/qwen2_5_omni/only_thinker.py -q mixed_modalities

# Read vision and audio inputs from a single video file
# NOTE: V1 engine does not support interleaved modalities yet.
VLLM_USE_V1=0 python examples/offline_inference/qwen2_5_omni/only_thinker.py -q use_audio_in_video

# Process audio inputs
python examples/offline_inference/audio_language.py --model-type qwen2_5_omni

# Process image inputs
python examples/offline_inference/vision_language.py --modality image --model-type qwen2_5_omni

# Process video inputs
python examples/offline_inference/vision_language.py --modality video --model-type qwen2_5_omni
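
For reference, a minimal Python sketch of thinker-only offline inference through the vLLM LLM API, mirroring the audio example above. The prompt string follows the chat template visible in the traceback later in this thread; the input file name and sampling settings are illustrative:

# Minimal thinker-only audio inference (sketch).
import librosa
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-Omni-7B", limit_mm_per_prompt={"audio": 1})

# Chat-template prompt with a single audio placeholder.
prompt = ("<|im_start|>user\n"
          "<|audio_bos|><|AUDIO|><|audio_eos|>\n"
          "What is recited in the audio?<|im_end|>\n"
          "<|im_start|>assistant\n")

audio, sr = librosa.load("sample.wav", sr=16000)  # illustrative input file

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"audio": (audio, sr)}},
    SamplingParams(temperature=0.0, max_tokens=128),
)
print(outputs[0].outputs[0].text)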

Notes

The whole Qwen2.5-Omni model includes three parts:

  • thinker: multimodal inputs -> text responses & hidden states
  • talker: text responses & hidden states from thinker -> speech codes
  • code2wav (streaming codec decoder): codes -> speech

This PR implements only the thinker part for now; it accepts multimodal inputs (images / videos / audios) and generates text responses, similar to other common VLMs.
We have also developed an end-to-end implementation (to be released soon), but because of its significant impact on the vLLM framework architecture, we will not create the related pull request for now.
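
To make the split concrete, a pseudocode view of the dataflow (a sketch; the function names are illustrative, not actual vLLM APIs):

# Pseudocode dataflow of the full Qwen2.5-Omni model; only the first stage
# (the thinker) is implemented by this PR.
def qwen2_5_omni(images, videos, audios):
    text, hidden_states = thinker(images, videos, audios)  # this PR: multimodal -> text
    speech_codes = talker(text, hidden_states)             # not in this PR
    speech = code2wav(speech_codes)                        # not in this PR (streaming codec decoder)
    return text, speech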

FIX #15563


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default; they only run the fastcheck CI, a small and essential subset of tests that quickly catches errors. You can run the remaining CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the documentation (Improvements or additions to documentation), multi-modality (Related to multi-modality (#4194)), and v1 labels Mar 19, 2025
@DarkLight1337 (Member) commented:

Sorry I don't have time to review in detail tonight, but from a quick glance, can you add this model to the following pages?

  • Supported Models page
  • tests/models/registry.py (set is_available_online=False to pass CI until the model repo is released on HF)
  • tests/models/multimodal/processing/test_common.py
  • tests/models/decoder_only/vision_language/test_models.py (optional for now)
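
For the registry item above, a minimal sketch of the kind of entry expected in tests/models/registry.py (the architecture name and the _HfExamplesInfo helper are assumptions based on neighboring entries in that file):

# tests/models/registry.py (sketch)
"Qwen2_5OmniThinkerForConditionalGeneration": _HfExamplesInfo(
    "Qwen/Qwen2.5-Omni-7B",
    is_available_online=False,  # pass CI until the model repo is released on HF
),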

@fyabc (Contributor, Author) commented Mar 19, 2025

Sorry I don't have time to review in detail tonight, but from a quick glance, can you add this model to the following pages?

  • Supported Models page
  • tests/models/registry.py (set is_available_online=False to pass CI until the model repo is released on HF)
  • tests/models/multimodal/processing/test_common.py
  • tests/models/decoder_only/vision_language/test_models.py (optional for now)

OK, I will add them tomorrow.

@ywang96 ywang96 self-assigned this Mar 19, 2025
@yangninghua

@fyabc Qwen/Qwen2.5-Omni-7B ??

@ywang96 (Member) commented Mar 21, 2025

Sorry for the delay - going to take a look at this PR tonight!

@ywang96 (Member) left a comment

Thank you for the contribution! I have left some comments!

@fyabc (Contributor, Author) commented Mar 24, 2025

Hi @ywang96 @DarkLight1337, I have updated some other examples here; please check the code.

@fyabc (Contributor, Author) commented Apr 18, 2025, quoting a report from @chunhuizng:

Hi fyabc,

Thank you for your excellent work on adding support for Qwen2.5-Omni in vLLM!

I noticed that the current instructions don’t fully work for my setup:

# Process audio inputs
python examples/offline_inference/audio_language.py --model-type qwen2_5_omni

Specifically, I encountered the following error during inference:

INFO 04-17 15:51:03 [llm_engine.py:449] init engine (profile, create kv cache, warmup model) took 7.81 seconds
[rank0]: Traceback (most recent call last):
[rank0]:   File "/mnt/ssd3/chunhui/vllm/vllm/inputs/registry.py", line 167, in call_hf_processor
[rank0]:     return hf_processor(**data, **merged_kwargs, return_tensors="pt")
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/ssd3/chunhui/r1aqa/lib/python3.11/site-packages/transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py", line 192, in __call__
[rank0]:     text = self.replace_multimodal_special_tokens(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/ssd3/chunhui/r1aqa/lib/python3.11/site-packages/transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py", line 234, in replace_multimodal_special_tokens
[rank0]:     sample = sample.replace(self.audio_token, "<|audio_placeholder|>" * next(audio_lengths), 1)
[rank0]:                                                                         ^^^^^^^^^^^^^^^^^^^
[rank0]: StopIteration

[rank0]: The above exception was the direct cause of the following exception:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/mnt/ssd3/chunhui/vllm/examples/offline_inference/audio_language.py", line 302, in <module>
[rank0]:     main(args)
[rank0]:   File "/mnt/ssd3/chunhui/vllm/examples/offline_inference/audio_language.py", line 289, in main
[rank0]:     outputs = llm.generate(
[rank0]:               ^^^^^^^^^^^^^
[rank0]:   File "/mnt/ssd3/chunhui/vllm/vllm/utils.py", line 1181, in inner
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/ssd3/chunhui/vllm/vllm/entrypoints/llm.py", line 462, in generate
[rank0]:     self._validate_and_add_requests(
[rank0]:   File "/mnt/ssd3/chunhui/vllm/vllm/entrypoints/llm.py", line 1356, in _validate_and_add_requests
[rank0]:     self._add_request(
[rank0]:   File "/mnt/ssd3/chunhui/vllm/vllm/entrypoints/llm.py", line 1374, in _add_request
[rank0]:     self.llm_engine.add_request(
[rank0]:   File "/mnt/ssd3/chunhui/vllm/vllm/utils.py", line 1181, in inner
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/ssd3/chunhui/vllm/vllm/engine/llm_engine.py", line 781, in add_request
[rank0]:     preprocessed_inputs = self.input_preprocessor.preprocess(
[rank0]:                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/ssd3/chunhui/vllm/vllm/inputs/preprocess.py", line 750, in preprocess
[rank0]:     return self._process_decoder_only_prompt(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/ssd3/chunhui/vllm/vllm/inputs/preprocess.py", line 699, in _process_decoder_only_prompt
[rank0]:     prompt_comps = self._prompt_to_llm_inputs(
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/ssd3/chunhui/vllm/vllm/inputs/preprocess.py", line 370, in _prompt_to_llm_inputs
[rank0]:     return self._process_multimodal(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/ssd3/chunhui/vllm/vllm/inputs/preprocess.py", line 275, in _process_multimodal
[rank0]:     return mm_processor.apply(prompt, mm_data, mm_processor_kwargs,
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/ssd3/chunhui/vllm/vllm/model_executor/models/qwen2_5_omni_thinker.py", line 335, in apply
[rank0]:     ) = self._cached_apply_hf_processor(
[rank0]:         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/ssd3/chunhui/vllm/vllm/multimodal/processing.py", line 1389, in _cached_apply_hf_processor
[rank0]:     ) = self._apply_hf_processor_main(
[rank0]:         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/ssd3/chunhui/vllm/vllm/multimodal/processing.py", line 1330, in _apply_hf_processor_main
[rank0]:     prompt_ids = self._apply_hf_processor_text_only(prompt)
[rank0]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/ssd3/chunhui/vllm/vllm/multimodal/processing.py", line 1258, in _apply_hf_processor_text_only
[rank0]:     prompt_ids, _, _ = self._apply_hf_processor_text_mm(
[rank0]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/ssd3/chunhui/vllm/vllm/multimodal/processing.py", line 1228, in _apply_hf_processor_text_mm
[rank0]:     processed_data = self._call_hf_processor(
[rank0]:                      ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/ssd3/chunhui/vllm/vllm/model_executor/models/qwen2_5_omni_thinker.py", line 275, in _call_hf_processor
[rank0]:     hf_inputs = super()._call_hf_processor(
[rank0]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/ssd3/chunhui/vllm/vllm/multimodal/processing.py", line 1191, in _call_hf_processor
[rank0]:     return self.info.ctx.call_hf_processor(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/mnt/ssd3/chunhui/vllm/vllm/inputs/registry.py", line 172, in call_hf_processor
[rank0]:     raise RuntimeError(msg) from exc
[rank0]: RuntimeError: Failed to apply Qwen2_5OmniProcessor on data={'text': '<|im_start|>system\nYou are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.<|im_end|>\n<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>\nWhat is recited in the audio?<|im_end|>\n<|im_start|>assistant\n'} with kwargs={}
[rank0]:[W417 15:51:08.818304976 ProcessGroupNCCL.cpp:1496] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())

After some debugging, I found a workaround that allows your vLLM branch to work smoothly:

  1. Install the latest version of your vLLM with [Model][VLM] Add Qwen2.5-Omni model support (thinker only).
  2. Install the latest transformers via:
    uv pip install git+https://github.com/huggingface/transformers
  3. Replace the src/transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py file with [this version](https://github.com/BakerBunker/transformers/blob/21dbefaa54e5bf180464696aa70af0bfc7a61d53/src/transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py).

After making this replacement, everything works as expected.
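
For anyone scripting step 3, one way to do the replacement in Python (a sketch; it assumes transformers is installed in a writable location and reconstructs the raw-file URL from the commit linked above):

# Overwrite the installed processor file with the compatible version (sketch).
import os
import urllib.request

import transformers

dst = os.path.join(os.path.dirname(transformers.__file__),
                   "models", "qwen2_5_omni", "processing_qwen2_5_omni.py")
src = ("https://raw.githubusercontent.com/BakerBunker/transformers/"
       "21dbefaa54e5bf180464696aa70af0bfc7a61d53/src/transformers/"
       "models/qwen2_5_omni/processing_qwen2_5_omni.py")
urllib.request.urlretrieve(src, dst)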

I suspect the issue arises because the latest version of transformers uses a processing_qwen2_5_omni.py implementation that is not yet compatible with your vLLM processor interface. If my understanding is incorrect, please feel free to clarify.

Thanks again for your great contribution!

Best,
Chunhui

Thank you for your report! This is a known issue (see #15130 (comment) for more details), and I am working on it.

@fyabc (Contributor, Author) commented Apr 18, 2025

Hi @chunhuizng, I have updated the PR and fixed the bug; please try again.

@fyabc (Contributor, Author) commented Apr 18, 2025

Is there anything else that needs to be done?

@DarkLight1337 The bug is fixed and all tests pass; I think this PR is ready to merge.

@DarkLight1337 DarkLight1337 added the ready (ONLY add when PR is ready to merge/full CI is needed) label and removed the structured-output and speculative-decoding labels Apr 18, 2025
@fyabc fyabc requested a review from ywang96 April 18, 2025 11:56
@DarkLight1337 (Member) commented Apr 18, 2025

Can you resolve the failures in the basic models test?

fyabc added 2 commits April 18, 2025 22:38
Signed-off-by: fyabc <suyang.fy@alibaba-inc.com>
@fyabc (Contributor, Author) commented Apr 18, 2025

Can you resolve the failures in the basic models test?

Hi @DarkLight1337, I have fixed the test registry; the API timeout error now seems to be raised outside of this PR.

@ywang96 (Member) left a comment

Very sorry for the long delay - let's get this in!

@ywang96 ywang96 merged commit 2c1bd84 into vllm-project:main Apr 19, 2025
49 checks passed
@github-project-automation github-project-automation bot moved this from In Progress to Done in Multi-modal Model Requests Apr 19, 2025
This pull request was later referenced by commits pushed to the following forks:

  • yangw-dev/vllm (Apr 21, 2025)
  • dtransposed/vllm (Apr 22, 2025)
  • liuzijing2014/vllm (Apr 25, 2025, two commits)
  • character-tech/vllm (Apr 28, 2025)
  • jikunshang/vllm (Apr 29, 2025)
  • lk-chen/vllm (Apr 29, 2025)
  • HabanaAI/vllm-fork (Apr 30, 2025)
  • s3woz/vllm (Apr 30, 2025)
Labels
ci/build, documentation (Improvements or additions to documentation), frontend, multi-modality (Related to multi-modality (#4194)), ready (ONLY add when PR is ready to merge/full CI is needed), v1
Projects
Status: Done
Development

Successfully merging this pull request may close these issues.

[New Model]: please surport for Qwen/Qwen2.5-Omni-7B