diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 1b90841fc2ab..18b23b3949f2 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -4560,7 +4560,7 @@ def generate( return thinker_result # 2. Generate speech tokens from talker module - embeds_to_talker = thinker_result.hidden_states[0][0].clone() + embeds_to_talker = thinker_result.hidden_states[0][0].clone().to(self.talker.device) if thinker_kwargs.get("input_features", None) is not None: audio_ids_mask = input_ids == self.config.thinker_config.audio_token_index audio_mask = audio_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device) diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py index c0c105032355..294a72c6b6cc 100644 --- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py @@ -4243,7 +4243,7 @@ def generate( return thinker_result # 2. Generate speech tokens from talker module - embeds_to_talker = thinker_result.hidden_states[0][0].clone() + embeds_to_talker = thinker_result.hidden_states[0][0].clone().to(self.talker.device) if thinker_kwargs.get("input_features", None) is not None: audio_ids_mask = input_ids == self.config.thinker_config.audio_token_index audio_mask = audio_ids_mask.unsqueeze(-1).expand_as(embeds_to_talker).to(embeds_to_talker.device)