From b88443b6c645942b89991c3df35f5485630e8df3 Mon Sep 17 00:00:00 2001
From: Mengqing Cao
Date: Fri, 14 Feb 2025 10:45:49 +0800
Subject: [PATCH] [dist] fix communicator patch (#58)

### What this PR does / why we need it?

Fix the communicator patch so that parallel execution works. See #52.

Signed-off-by: MengqingCao
---
 setup.py                               | 2 +-
 vllm_ascend/patch/patch_commnicator.py | 6 +++---
 vllm_ascend/platform.py                | 3 +--
 vllm_ascend/worker.py                  | 2 ++
 4 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/setup.py b/setup.py
index 2553521d..d278ef99 100644
--- a/setup.py
+++ b/setup.py
@@ -95,7 +95,7 @@ def _read_requirements(filename: str) -> List[str]:
         "Topic :: Scientific/Engineering :: Artificial Intelligence",
         "Topic :: Scientific/Engineering :: Information Analysis",
     ],
-    packages=find_packages(exclude=("docs", "examples", "tests*", "patch")),
+    packages=find_packages(exclude=("docs", "examples", "tests*")),
     python_requires=">=3.9",
     install_requires=get_requirements(),
     extras_require={},
diff --git a/vllm_ascend/patch/patch_commnicator.py b/vllm_ascend/patch/patch_commnicator.py
index 45f34954..15a8563f 100644
--- a/vllm_ascend/patch/patch_commnicator.py
+++ b/vllm_ascend/patch/patch_commnicator.py
@@ -19,11 +19,11 @@
 # https://github.com/vllm-project/vllm/pull/11324.
 
 import torch
-from vllm.distributed.parallel_state import GroupCoordinator
+import vllm
 from vllm.utils import resolve_obj_by_qualname
 
 
-class GroupCoordinatorPatch(GroupCoordinator):
+class GroupCoordinatorPatch(vllm.distributed.parallel_state.GroupCoordinator):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -66,4 +66,4 @@ def all_gather(self, input_, dim=-1):
         return self.communicator.all_gather(input_, dim)
 
 
-GroupCoordinator = GroupCoordinatorPatch
+vllm.distributed.parallel_state.GroupCoordinator = GroupCoordinatorPatch
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index 242cf528..2b847de1 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -88,9 +88,8 @@ def mem_get_info(cls) -> Tuple[int, int]:
 
     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
-        # Register ops and patch when setup.
+        # Register ops when setup.
         from vllm_ascend import ops  # noqa: F401
-        from vllm_ascend import patch  # noqa: F401
 
         parallel_config = vllm_config.parallel_config
         if parallel_config.worker_cls == "auto":
diff --git a/vllm_ascend/worker.py b/vllm_ascend/worker.py
index c5884e36..cecff11e 100644
--- a/vllm_ascend/worker.py
+++ b/vllm_ascend/worker.py
@@ -457,6 +457,8 @@ def init_worker_distributed_environment(
         backend: str = "hccl") -> None:
     """Initialize the distributed environment."""
     set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
+    # register communicator patch before init dist env
+    from vllm_ascend import patch  # noqa: F401
 
     init_distributed_environment(parallel_config.world_size, rank,
                                  distributed_init_method, local_rank, backend)
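
Note for reviewers: below is a minimal, self-contained sketch of the Python name-binding behavior this diff relies on. The `parallel_state` module and the coordinator classes here are stand-ins built with `types.ModuleType` for illustration only, not the real vllm objects. The old patch only rebound a local name, so vllm never saw the subclass; this PR rebinds the attribute on the module itself.

```python
import types

# Hypothetical stand-in for vllm.distributed.parallel_state (illustration only).
parallel_state = types.ModuleType("parallel_state")


class _OriginalGroupCoordinator:
    def all_gather(self, input_, dim=-1):
        return "original"


parallel_state.GroupCoordinator = _OriginalGroupCoordinator

# Old style: `from ...parallel_state import GroupCoordinator` binds the class
# to a local name in the patch module...
GroupCoordinator = parallel_state.GroupCoordinator


class GroupCoordinatorPatch(GroupCoordinator):
    def all_gather(self, input_, dim=-1):
        return "patched"


# ...so `GroupCoordinator = GroupCoordinatorPatch` only rebinds that local name;
# code resolving the class through the module still gets the original.
GroupCoordinator = GroupCoordinatorPatch
assert parallel_state.GroupCoordinator().all_gather(None) == "original"

# New style (this PR): rebind the attribute on the module itself, so every
# later lookup through the module sees the patched class.
parallel_state.GroupCoordinator = GroupCoordinatorPatch
assert parallel_state.GroupCoordinator().all_gather(None) == "patched"
```

This is also why the `from vllm_ascend import patch` import moves from `check_and_update_config` into `init_worker_distributed_environment`: the module attribute has to be rebound in the worker process before `init_distributed_environment` constructs the group coordinators.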