
Commit 2c1bd84

fyabc, ywang96, and wangxiongts authored
[Model][VLM] Add Qwen2.5-Omni model support (thinker only) (#15130)
Signed-off-by: fyabc <suyang.fy@alibaba-inc.com>
Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
Co-authored-by: Xiong Wang <wangxiongts@163.com>
1 parent 5c91212 commit 2c1bd84

File tree

23 files changed: +1852, -82 lines changed


docs/source/models/supported_models.md (+15)

@@ -1040,6 +1040,13 @@ See [this page](#generative-models) for more information on how to use generative models.
   * ✅︎
   * ✅︎
   * ✅︎
+- * `Qwen2_5OmniThinkerForConditionalGeneration`
+  * Qwen2.5-Omni
+  * T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup>
+  * `Qwen/Qwen2.5-Omni-7B`
+  *
+  * ✅︎
+  * ✅︎\*
 - * `SkyworkR1VChatModel`
   * Skywork-R1V-38B
   * T + I
@@ -1109,6 +1116,14 @@ For more details, please see: <gh-pr:4087#issuecomment-2250397630>
 Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1.
 :::

+:::{note}
+To use Qwen2.5-Omni, you have to install a fork of the Hugging Face Transformers library from source via
+`pip install git+https://github.com/BakerBunker/transformers.git@qwen25omni`.
+
+Reading audio from the video during pre-processing is currently only supported on V0 (not V1), because overlapping modalities are not yet supported in V1. Enable it with
+`--mm-processor-kwargs '{"use_audio_in_video": True}'`.
+:::
+
 ### Pooling Models

 See [this page](pooling-models) for more information on how to use pooling models.
examples/offline_inference/audio_language.py (+31)

@@ -130,6 +130,36 @@ def run_qwen2_audio(question: str, audio_count: int) -> ModelRequestData:
     )


+# Qwen2.5-Omni
+def run_qwen2_5_omni(question: str, audio_count: int):
+    model_name = "Qwen/Qwen2.5-Omni-7B"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=4096,
+        max_num_seqs=5,
+        limit_mm_per_prompt={"audio": audio_count},
+    )
+
+    audio_in_prompt = "".join([
+        "<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count)
+    ])
+
+    default_system = (
+        "You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
+        "Group, capable of perceiving auditory and visual inputs, as well as "
+        "generating text and speech.")
+
+    prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n"
+              "<|im_start|>user\n"
+              f"{audio_in_prompt}{question}<|im_end|>\n"
+              "<|im_start|>assistant\n")
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompt=prompt,
+    )
+
+
 # Ultravox 0.5-1B
 def run_ultravox(question: str, audio_count: int) -> ModelRequestData:
     model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"

@@ -182,6 +212,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:
     "minicpmo": run_minicpmo,
     "phi4_mm": run_phi4mm,
     "qwen2_audio": run_qwen2_audio,
+    "qwen2_5_omni": run_qwen2_5_omni,
     "ultravox": run_ultravox,
     "whisper": run_whisper,
 }
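
For context, a standalone sketch of how the prompt built by `run_qwen2_5_omni` pairs with audio inputs at generation time; it reuses the `AudioAsset` clips and the dict-prompt `llm.generate` call shown in `only_thinker.py` later in this commit, so only the wiring below is assumed:

```python
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

audio_count = 2

# One <|AUDIO|> placeholder per clip; must match limit_mm_per_prompt["audio"].
audio_in_prompt = "".join(
    "<|audio_bos|><|AUDIO|><|audio_eos|>\n" for _ in range(audio_count))
prompt = (f"<|im_start|>user\n{audio_in_prompt}"
          "Are these two audio clips the same?<|im_end|>\n"
          "<|im_start|>assistant\n")

llm = LLM(model="Qwen/Qwen2.5-Omni-7B",
          max_model_len=4096,
          max_num_seqs=5,
          limit_mm_per_prompt={"audio": audio_count})

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {
            # Each entry is a (waveform, sample_rate) tuple.
            "audio": [
                AudioAsset("winning_call").audio_and_sample_rate,
                AudioAsset("mary_had_lamb").audio_and_sample_rate,
            ],
        },
    },
    sampling_params=SamplingParams(temperature=0.2, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```
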
examples/offline_inference/qwen2_5_omni/README.md (+32)

@@ -0,0 +1,32 @@
+# Qwen2.5-Omni Offline Inference Examples
+
+This folder provides several example scripts on how to run Qwen2.5-Omni inference offline.
+
+## Thinker Only
+
+```bash
+# Audio + image + video
+python examples/offline_inference/qwen2_5_omni/only_thinker.py -q mixed_modalities
+
+# Read vision and audio inputs from a single video file
+# NOTE: V1 engine does not support interleaved modalities yet.
+VLLM_USE_V1=0 python examples/offline_inference/qwen2_5_omni/only_thinker.py -q use_audio_in_video
+
+# Multiple audios
+VLLM_USE_V1=0 python examples/offline_inference/qwen2_5_omni/only_thinker.py -q multi_audios
+```
+
+This script runs the thinker part of Qwen2.5-Omni and generates text responses.
+
+You can also test Qwen2.5-Omni on a single modality:
+
+```bash
+# Process audio inputs
+python examples/offline_inference/audio_language.py --model-type qwen2_5_omni
+
+# Process image inputs
+python examples/offline_inference/vision_language.py --modality image --model-type qwen2_5_omni
+
+# Process video inputs
+python examples/offline_inference/vision_language.py --modality video --model-type qwen2_5_omni
+```
examples/offline_inference/qwen2_5_omni/only_thinker.py (+160)

@@ -0,0 +1,160 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+This example shows how to use vLLM for running offline inference
+with the correct prompt format on Qwen2.5-Omni (thinker only).
+"""
+
+from typing import NamedTuple
+
+import vllm.envs as envs
+from vllm import LLM, SamplingParams
+from vllm.assets.audio import AudioAsset
+from vllm.assets.image import ImageAsset
+from vllm.assets.video import VideoAsset
+from vllm.utils import FlexibleArgumentParser
+
+
+class QueryResult(NamedTuple):
+    inputs: dict
+    limit_mm_per_prompt: dict[str, int]
+
+
+# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
+# lower-end GPUs.
+# Unless specified, these settings have been tested to work on a single L4.
+
+default_system = (
+    "You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
+    "Group, capable of perceiving auditory and visual inputs, as well as "
+    "generating text and speech.")
+
+
+def get_mixed_modalities_query() -> QueryResult:
+    question = ("What is recited in the audio? "
+                "What is the content of this image? Why is this video funny?")
+    prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n"
+              "<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
+              "<|vision_bos|><|IMAGE|><|vision_eos|>"
+              "<|vision_bos|><|VIDEO|><|vision_eos|>"
+              f"{question}<|im_end|>\n"
+              f"<|im_start|>assistant\n")
+    return QueryResult(
+        inputs={
+            "prompt": prompt,
+            "multi_modal_data": {
+                "audio":
+                AudioAsset("mary_had_lamb").audio_and_sample_rate,
+                "image":
+                ImageAsset("cherry_blossom").pil_image.convert("RGB"),
+                "video":
+                VideoAsset(name="sample_demo_1.mp4",
+                           num_frames=16).np_ndarrays,
+            },
+        },
+        limit_mm_per_prompt={
+            "audio": 1,
+            "image": 1,
+            "video": 1
+        },
+    )
+
+
+def get_use_audio_in_video_query() -> QueryResult:
+    question = ("Describe the content of the video, "
                "then convert what the baby says into text.")
+    prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n"
+              "<|im_start|>user\n<|vision_bos|><|VIDEO|><|vision_eos|>"
+              f"{question}<|im_end|>\n"
+              f"<|im_start|>assistant\n")
+    asset = VideoAsset(name="sample_demo_1.mp4", num_frames=16)
+    audio = asset.get_audio(sampling_rate=16000)
+    assert not envs.VLLM_USE_V1, ("V1 does not support use_audio_in_video. "
+                                  "Please launch this example with "
+                                  "`VLLM_USE_V1=0`.")
+    return QueryResult(
+        inputs={
+            "prompt": prompt,
+            "multi_modal_data": {
+                "video": asset.np_ndarrays,
+                "audio": audio,
+            },
+            "mm_processor_kwargs": {
+                "use_audio_in_video": True,
+            },
+        },
+        limit_mm_per_prompt={
+            "audio": 1,
+            "video": 1
+        },
+    )
+
+
+def get_multi_audios_query() -> QueryResult:
+    question = "Are these two audio clips the same?"
+    prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n"
+              "<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
+              "<|audio_bos|><|AUDIO|><|audio_eos|>"
+              f"{question}<|im_end|>\n"
+              f"<|im_start|>assistant\n")
+    return QueryResult(
+        inputs={
+            "prompt": prompt,
+            "multi_modal_data": {
+                "audio": [
+                    AudioAsset("winning_call").audio_and_sample_rate,
+                    AudioAsset("mary_had_lamb").audio_and_sample_rate,
+                ],
+            },
+        },
+        limit_mm_per_prompt={
+            "audio": 2,
+        },
+    )
+
+
+query_map = {
+    "mixed_modalities": get_mixed_modalities_query,
+    "use_audio_in_video": get_use_audio_in_video_query,
+    "multi_audios": get_multi_audios_query,
+}
+
+
+def main(args):
+    model_name = "Qwen/Qwen2.5-Omni-7B"
+    query_result = query_map[args.query_type]()
+
+    llm = LLM(model=model_name,
+              max_model_len=5632,
+              max_num_seqs=5,
+              limit_mm_per_prompt=query_result.limit_mm_per_prompt,
+              seed=args.seed)
+
+    # We set temperature to 0.2 so that outputs can be different
+    # even when all prompts are identical when running batch inference.
+    sampling_params = SamplingParams(temperature=0.2, max_tokens=64)
+
+    outputs = llm.generate(query_result.inputs,
+                           sampling_params=sampling_params)
+
+    for o in outputs:
+        generated_text = o.outputs[0].text
+        print(generated_text)
+
+
+if __name__ == "__main__":
+    parser = FlexibleArgumentParser(
+        description='Demo on using vLLM for offline inference with '
+        'audio language models')
+    parser.add_argument('--query-type',
+                        '-q',
+                        type=str,
+                        default="mixed_modalities",
+                        choices=query_map.keys(),
+                        help='Query type.')
+    parser.add_argument("--seed",
+                        type=int,
+                        default=None,
+                        help="Set the seed when initializing `vllm.LLM`.")
+
+    args = parser.parse_args()
+    main(args)

examples/offline_inference/vision_language.py (+37)

@@ -941,6 +941,42 @@ def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData:
     )


+# Qwen2.5-Omni
+def run_qwen2_5_omni(questions: list[str], modality: str):
+    model_name = "Qwen/Qwen2.5-Omni-7B"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=4096,
+        max_num_seqs=5,
+        mm_processor_kwargs={
+            "min_pixels": 28 * 28,
+            "max_pixels": 1280 * 28 * 28,
+            "fps": [1],
+        },
+        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+    )
+
+    if modality == "image":
+        placeholder = "<|IMAGE|>"
+    elif modality == "video":
+        placeholder = "<|VIDEO|>"
+
+    default_system = (
+        "You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
+        "Group, capable of perceiving auditory and visual inputs, as well as "
+        "generating text and speech.")
+
+    prompts = [(f"<|im_start|>system\n{default_system}<|im_end|>\n"
+                f"<|im_start|>user\n<|vision_bos|>{placeholder}<|vision_eos|>"
+                f"{question}<|im_end|>\n"
+                "<|im_start|>assistant\n") for question in questions]
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompts=prompts,
+    )
+
+
 # SkyworkR1V
 def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
     assert modality == "image"

@@ -1010,6 +1046,7 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
     "qwen_vl": run_qwen_vl,
     "qwen2_vl": run_qwen2_vl,
     "qwen2_5_vl": run_qwen2_5_vl,
+    "qwen2_5_omni": run_qwen2_5_omni,
     "skywork_chat": run_skyworkr1v,
     "smolvlm": run_smolvlm,
 }
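
A short aside on the `mm_processor_kwargs` used in `run_qwen2_5_omni` above, interpreted under the Qwen2-VL-style processor convention (a hedged reading, not part of this commit): `min_pixels` and `max_pixels` bound the area an image may be resized to, expressed in 28x28-pixel patches, and `fps` fixes the sampled video frame rate.

```python
# Illustrative arithmetic for the bounds above, assuming the Qwen2-VL-style
# convention that images are resized so their area stays within
# [min_pixels, max_pixels], measured against 28x28-pixel patches.
patch_area = 28 * 28
min_pixels = 1 * patch_area        # smallest allowed image area (1 patch)
max_pixels = 1280 * patch_area     # largest allowed image area (1280 patches)

# A 1920x1080 frame (~2.07 MP) exceeds max_pixels (~1.00 MP), so the
# processor would downscale it before patching.
print(1920 * 1080 > max_pixels)    # True
```
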

tests/models/decoder_only/vision_language/test_models.py (+17)

@@ -139,6 +139,23 @@
         image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
         marks=[pytest.mark.core_model, pytest.mark.cpu_model],
     ),
+    "qwen2_5_omni": VLMTestInfo(
+        models=["Qwen/Qwen2.5-Omni-7B"],
+        test_type=(
+            VLMTestType.IMAGE,
+            VLMTestType.MULTI_IMAGE,
+            VLMTestType.VIDEO
+        ),
+        prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n",  # noqa: E501
+        img_idx_to_prompt=lambda idx: "<|vision_bos|><|IMAGE|><|vision_eos|>",  # noqa: E501
+        video_idx_to_prompt=lambda idx: "<|vision_bos|><|VIDEO|><|vision_eos|>",  # noqa: E501
+        max_model_len=4096,
+        max_num_seqs=2,
+        auto_cls=AutoModelForVision2Seq,
+        vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
+        image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
+        marks=[pytest.mark.core_model, pytest.mark.cpu_model],
+    ),
     #### Extended model tests
     "aria": VLMTestInfo(
         models=["rhymes-ai/Aria"],

tests/models/multimodal/processing/test_common.py (+1)

@@ -280,6 +280,7 @@ def _test_processing_correctness_mistral(
     "Qwen/Qwen2-VL-2B-Instruct",
     "Qwen/Qwen2.5-VL-3B-Instruct",
     "Qwen/Qwen2-Audio-7B-Instruct",
+    "Qwen/Qwen2.5-Omni-7B",
     "Skywork/Skywork-R1V-38B",
     "fixie-ai/ultravox-v0_5-llama-3_2-1b",
     "openai/whisper-large-v3",

tests/models/registry.py (+2)

@@ -362,6 +362,8 @@ def check_available_online(
     "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"),  # noqa: E501
     "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct",  # noqa: E501
                                                           min_transformers_version="4.49"),  # noqa: E501
+    "Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B",  # noqa: E501
+                                        min_transformers_version="4.52"),  # noqa: E501
     "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"),
     "SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"),  # noqa: E501
     "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b",  # noqa: E501

vllm/assets/video.py (+17 -1)

@@ -2,16 +2,23 @@

 from dataclasses import dataclass
 from functools import lru_cache
-from typing import Literal
+from typing import Literal, Optional

 import cv2
 import numpy as np
 import numpy.typing as npt
 from huggingface_hub import hf_hub_download
 from PIL import Image

+from vllm.utils import PlaceholderModule
+
 from .base import get_cache_dir

+try:
+    import librosa
+except ImportError:
+    librosa = PlaceholderModule("librosa")  # type: ignore[assignment]
+

 @lru_cache
 def download_video_asset(filename: str) -> str:

@@ -85,3 +92,12 @@ def np_ndarrays(self) -> npt.NDArray:
         video_path = download_video_asset(self.name)
         ret = video_to_ndarrays(video_path, self.num_frames)
         return ret
+
+    def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray:
+        """
+        Read audio data from the video asset, used in Qwen2.5-Omni examples.
+
+        See also: examples/offline_inference/qwen2_5_omni/only_thinker.py
+        """
+        video_path = download_video_asset(self.name)
+        return librosa.load(video_path, sr=sampling_rate)[0]
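
The new `get_audio` helper is what the `use_audio_in_video` example relies on; a minimal usage sketch, assuming the `sample_demo_1.mp4` asset used in `only_thinker.py`:

```python
from vllm.assets.video import VideoAsset

# Load frames and the accompanying audio track from the same asset, as done in
# examples/offline_inference/qwen2_5_omni/only_thinker.py.
asset = VideoAsset(name="sample_demo_1.mp4", num_frames=16)
frames = asset.np_ndarrays                    # sampled video frames (numpy)
audio = asset.get_audio(sampling_rate=16000)  # waveform resampled to 16 kHz

# Both can then be passed together as multi_modal_data for a single request.
multi_modal_data = {"video": frames, "audio": audio}
```
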
