diff --git a/vllm_ascend/__init__.py b/vllm_ascend/__init__.py
index 80af5a52..81c91ddb 100644
--- a/vllm_ascend/__init__.py
+++ b/vllm_ascend/__init__.py
@@ -18,4 +18,6 @@
 def register():
     """Register the NPU platform."""
+    # To ensure that the module is correctly replaced, add it at the beginning
+    import vllm_ascend.patch_module  # noqa: F401
 
     return "vllm_ascend.platform.NPUPlatform"
diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py
new file mode 100644
index 00000000..e6020f36
--- /dev/null
+++ b/vllm_ascend/patch/__init__.py
@@ -0,0 +1,17 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import vllm_ascend.patch.patch_minicpm  # noqa
diff --git a/vllm_ascend/patch/patch_minicpm.py b/vllm_ascend/patch/patch_minicpm.py
new file mode 100644
index 00000000..4828ccc7
--- /dev/null
+++ b/vllm_ascend/patch/patch_minicpm.py
@@ -0,0 +1,39 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+from vllm.attention import AttentionMetadata
+from vllm.model_executor.models.minicpm import MiniCPMAttention
+
+
+def forward(
+    self,
+    positions: torch.Tensor,
+    hidden_states: torch.Tensor,
+    kv_cache: torch.Tensor,
+    attn_metadata: AttentionMetadata,
+) -> torch.Tensor:
+    qkv, _ = self.qkv_proj(hidden_states)
+    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+    q, k = self.rotary_emb(positions, q, k)
+    attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+    output, _ = self.o_proj(attn_output)
+    return output
+
+
+# The type conversion in the forward function is deleted to support the rope operator.
+MiniCPMAttention.forward = forward
diff --git a/vllm_ascend/patch_module.py b/vllm_ascend/patch_module.py
new file mode 100644
index 00000000..2bbedac1
--- /dev/null
+++ b/vllm_ascend/patch_module.py
@@ -0,0 +1,23 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import sys
+import types
+
+# Prevent import errors caused by triton not being supported
+sys.modules[
+    'vllm.model_executor.layers.fused_moe.fused_moe'] = types.ModuleType(
+        'fused_moe_module')
diff --git a/vllm_ascend/worker.py b/vllm_ascend/worker.py
index 27930d4d..98f53aca 100644
--- a/vllm_ascend/worker.py
+++ b/vllm_ascend/worker.py
@@ -68,8 +68,9 @@ def __init__(
         is_driver_worker: bool = False,
         model_runner_cls: Optional[Type[ModelRunnerBase]] = None,
     ) -> None:
-        # Register ops when worker init.
+        # Register ops and apply patches when the worker is initialized.
         from vllm_ascend import ops  # noqa: F401
+        from vllm_ascend import patch  # noqa: F401
         WorkerBase.__init__(self, vllm_config=vllm_config)
 
         # Try to import mindie_turbo to accelerate vLLM inference.
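For context on the `patch_module.py` change: Python caches imported modules in `sys.modules` and consults that cache before executing a module's file, so pre-registering an empty placeholder under the target's dotted name means the triton-dependent `fused_moe` module body never runs. A minimal, self-contained sketch of the technique; the `stub_out` helper is illustrative and not part of the patch:

```python
import sys
import types


def stub_out(dotted_name: str) -> types.ModuleType:
    """Pre-register an empty module so a later import of `dotted_name`
    returns the stub instead of executing the real module's code."""
    stub = types.ModuleType(dotted_name)
    sys.modules[dotted_name] = stub
    return stub


# Must run before anything imports the real module: once a name is in
# sys.modules, Python returns the cached entry and never re-imports it.
stub_out('vllm.model_executor.layers.fused_moe.fused_moe')
```

This ordering requirement is also why `register()` imports `patch_module` first: the stub has to land in `sys.modules` before vLLM itself imports `fused_moe`.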
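`patch_minicpm.py` uses the standard method monkey-patching pattern: define a module-level function with the same signature as the target method and assign it onto the class, so every existing and future instance picks up the replacement. A toy sketch under hypothetical names (`Greeter` and `greet` are illustrative only):

```python
class Greeter:

    def greet(self, name: str) -> str:
        # Original behavior to be replaced, analogous to the dtype
        # conversion the patch removes from MiniCPMAttention.forward.
        return f"Hello, {name.upper()}"


def greet(self, name: str) -> str:
    # The replacement keeps the original signature; assigning it to the
    # class makes `self` bind normally, exactly like any other method.
    return f"Hello, {name}"


Greeter.greet = greet

assert Greeter().greet("world") == "Hello, world"
```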