diff --git a/vllm_ascend/communicator.py b/vllm_ascend/communicator.py
index efef46e9..afb39f7f 100644
--- a/vllm_ascend/communicator.py
+++ b/vllm_ascend/communicator.py
@@ -17,12 +17,62 @@
 import torch
 import torch.distributed as dist
 
-from vllm.distributed.device_communicators.base_communicator import \
-    CommunicatorBase
-
 
-class NPUCommunicator(CommunicatorBase):
+class NPUCommunicator:
+
+    def __init__(self, group, unique_name=""):
+        self.group = group
+        self.unique_name = unique_name
+        self.rank = dist.get_rank(group)
+        self.world_size = dist.get_world_size(self.group)
+        self.ranks = dist.get_process_group_ranks(self.group)
+        global_rank = dist.get_rank()
+        self.rank_in_group = dist.get_group_rank(self.group, global_rank)
 
     def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
         dist.all_reduce(x, group=self.group)
         return x
+
+    def gather(self, input_: torch.Tensor, dst: int = 0, dim: int = -1):
+        # NOTE: We assume that the input tensor is on the same device across
+        # all the ranks.
+        # NOTE: `dst` is the destination's rank within the group, not its
+        # global rank.
+        # Allocate output tensor.
+        if self.rank_in_group == dst:
+            gather_list = [
+                torch.empty_like(input_) for _ in range(self.world_size)
+            ]
+        else:
+            gather_list = None
+        # Gather.
+        dist.gather(input_, gather_list, dst=self.ranks[dst], group=self.group)
+        if self.rank_in_group == dst:
+            output_tensor = torch.cat(gather_list, dim=dim)
+        else:
+            output_tensor = None
+        return output_tensor
+
+    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        if dim < 0:
+            # Convert negative dim to positive.
+            dim += input_.dim()
+        input_size = input_.size()
+        # NOTE: We must use a concat-style all-gather here; stack-style
+        # all-gather has compatibility issues with torch.compile.
+        # See https://github.com/pytorch/pytorch/issues/138795
+        output_size = (input_size[0] * self.world_size, ) + input_size[1:]
+        # Allocate output tensor.
+        output_tensor = torch.empty(output_size,
+                                    dtype=input_.dtype,
+                                    device=input_.device)
+        # All-gather.
+        dist.all_gather_into_tensor(output_tensor, input_, group=self.group)
+        # Reshape
+        output_tensor = output_tensor.reshape((self.world_size, ) + input_size)
+        output_tensor = output_tensor.movedim(0, dim)
+        output_tensor = output_tensor.reshape(input_size[:dim] +
+                                              (self.world_size *
+                                               input_size[dim], ) +
+                                              input_size[dim + 1:])
+        return output_tensor
diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py
new file mode 100644
index 00000000..f03d4b4f
--- /dev/null
+++ b/vllm_ascend/patch/__init__.py
@@ -0,0 +1,18 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from vllm_ascend.patch import patch_commnicator  # noqa
diff --git a/vllm_ascend/patch/patch_commnicator.py b/vllm_ascend/patch/patch_commnicator.py
new file mode 100644
index 00000000..45f34954
--- /dev/null
+++ b/vllm_ascend/patch/patch_commnicator.py
@@ -0,0 +1,71 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# This file is used to monkey patch the communicator in vllm to support
+# Ascend. Remove this file once the refactor lands in vllm via
+# https://github.com/vllm-project/vllm/pull/11324.
+
+import torch
+import vllm
+from vllm.distributed.parallel_state import GroupCoordinator
+from vllm.utils import resolve_obj_by_qualname
+
+
+class GroupCoordinatorPatch(GroupCoordinator):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.device = torch.device(f"npu:{self.local_rank}")
+
+        from vllm.platforms import current_platform
+        device_comm_cls = resolve_obj_by_qualname(
+            current_platform.get_device_communicator_cls())
+        # We have checked and ensured that reusing the tpu tag here is fine.
+        use_custom_device = kwargs.get("use_tpu_communicator", False)
+        if use_custom_device and self.world_size > 1:
+            self.communicator = device_comm_cls(group=self.device_group,
+                                                unique_name=self.unique_name)
+
+    def all_reduce(self, input_):
+        # Bypass the function if we are using only 1 device.
+        if self.world_size == 1:
+            return input_
+
+        return self.communicator.all_reduce(input_)
+
+    def gather(self, input_, dst=0, dim=-1):
+        # Bypass the function if we are using only 1 device.
+        if self.world_size == 1:
+            return input_
+        assert -input_.dim() <= dim < input_.dim(), (
+            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
+        if dim < 0:
+            # Convert negative dim to positive.
+            dim += input_.dim()
+
+        return self.communicator.gather(input_, dst, dim)
+
+    def all_gather(self, input_, dim=-1):
+        # Bypass the function if we are using only 1 device.
+        if self.world_size == 1:
+            return input_
+        assert -input_.dim() <= dim < input_.dim(), (
+            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
+        return self.communicator.all_gather(input_, dim)
+
+
+# Rebind on the vllm module so vllm itself constructs the patched class.
+vllm.distributed.parallel_state.GroupCoordinator = GroupCoordinatorPatch
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index 2b847de1..242cf528 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -88,8 +88,9 @@ def mem_get_info(cls) -> Tuple[int, int]:
 
     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
-        # Register ops when setup.
+        # Register ops and apply patches during setup.
         from vllm_ascend import ops  # noqa: F401
+        from vllm_ascend import patch  # noqa: F401
 
         parallel_config = vllm_config.parallel_config
         if parallel_config.worker_cls == "auto":
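
The least obvious part of the patch is the reshape pipeline at the end of `NPUCommunicator.all_gather`. The sketch below replays that logic in a single process; `chunks` stands in for the per-rank tensors that `dist.all_gather_into_tensor` concatenates into the flat output buffer, and `world_size`/`dim` are illustrative values, not anything fixed by the patch.

```python
import torch

world_size = 4
dim = 1
chunks = [torch.randn(2, 3, 5) for _ in range(world_size)]
input_size = chunks[0].size()

# What dist.all_gather_into_tensor writes into the flat output buffer:
# all ranks' tensors concatenated along the first dimension.
output_tensor = torch.cat(chunks, dim=0)

# The reshape pipeline from the patch: recover the per-rank axis,
# move it next to the target dim, then fold it into that dim.
output_tensor = output_tensor.reshape((world_size, ) + input_size)
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
                                      (world_size * input_size[dim], ) +
                                      input_size[dim + 1:])

# The result matches a stack-style gather concatenated along `dim`.
assert torch.equal(output_tensor, torch.cat(chunks, dim=dim))
print(output_tensor.shape)  # torch.Size([2, 12, 5])
```

In words: the flat buffer is split back into a leading per-rank axis, that axis is moved next to the target dimension, and the two are folded together. This produces exactly what a stack-style all-gather plus flatten would, without hitting the torch.compile issue linked in the code.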
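
The module-level assignment at the bottom of `patch_commnicator.py` is what makes the patch take effect: it rebinds `GroupCoordinator` on `vllm.distributed.parallel_state` itself, so every later `GroupCoordinator(...)` construction inside vllm resolves to `GroupCoordinatorPatch` (rebinding only the name imported into the patch module would change nothing in vllm). Below is a minimal sketch of the visible effect, assuming vllm and vllm-ascend are importable and nothing imported the patch earlier.

```python
# Sketch only: check that importing the patch package rebinds the class
# that vllm.distributed.parallel_state exposes.
import vllm.distributed.parallel_state as parallel_state

import vllm_ascend.patch  # noqa: F401  # importing applies the patch

assert parallel_state.GroupCoordinator.__name__ == "GroupCoordinatorPatch"
```

One caveat inherent to this style of monkey patching: any module that grabbed the class earlier via `from vllm.distributed.parallel_state import GroupCoordinator` keeps the original binding. That is why `check_and_update_config` imports `vllm_ascend.patch` during platform setup, before any process groups are created.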
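
The collective calls themselves (`dist.gather`, `dist.all_gather_into_tensor`) are backend-agnostic, so the communicator's control flow can be smoke-tested off-device. Here is a hedged CPU sketch using the `gloo` backend; the port, shapes, and two-process setup are arbitrary, and it uses list-style `dist.all_gather` because `gloo` support for `all_gather_into_tensor` varies across torch versions.

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29511"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    x = torch.full((2, 3), float(rank))

    # gather: only the destination rank allocates receive buffers,
    # mirroring the gather_list handling in NPUCommunicator.gather.
    gather_list = ([torch.empty_like(x) for _ in range(world_size)]
                   if rank == 0 else None)
    dist.gather(x, gather_list, dst=0)
    if rank == 0:
        assert torch.cat(gather_list, dim=-1).shape == (2, 3 * world_size)

    # all_gather: every rank receives every chunk.
    outs = [torch.empty_like(x) for _ in range(world_size)]
    dist.all_gather(outs, x)
    assert torch.cat(outs, dim=0).shape == (2 * world_size, 3)

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(2, ), nprocs=2)
```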