-
Notifications
You must be signed in to change notification settings - Fork 37
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Communicator] Add monkey patch #30
Merged
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Closed
MengqingCao
reviewed
Feb 11, 2025
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
4779824
to
b496ef7
Compare
test scriptsI made some modification so that we could run test on npu. (will make a pr on vLLM) """Test the communication operators.
Run `pytest tests/distributed/test_comm_ops.py`.
"""
import os
import pytest
import ray
import torch
from vllm.distributed import (broadcast_tensor_dict, get_pp_group,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce, cleanup_dist_env_and_memory)
from vllm.platforms import current_platform
from ..utils import init_test_distributed_environment, multi_process_parallel
DEVICE_TYPE = current_platform.device_type
@ray.remote(num_gpus=0, resources={current_platform.ray_device_key: 1}, max_calls=1)
def all_reduce_test_worker(tp_size: int, pp_size: int, rank: int,
distributed_init_port: str):
# it is important to delete the ASCEND_RT_VISIBLE_DEVICES environment variable
# so that each worker can see all the GPUs
# they will be able to set the device to the correct GPU
# del os.environ["ASCEND_RT_VISIBLE_DEVICES"]
# os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1"
# os.environ["RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES"] = "1"
device = torch.device(f"{DEVICE_TYPE}:{rank}")
torch.npu.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port, backend="hccl")
num_elements = 8
all_tensors = [
torch.arange(num_elements, dtype=torch.float32, device="npu") *
(r + 1) for r in range(tp_size)
]
expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
t = all_tensors[rank % tp_size]
t = tensor_model_parallel_all_reduce(t)
torch.testing.assert_close(t, expected)
# @ray.remote(num_gpus=1, max_calls=1)
@ray.remote(num_gpus=0, resources={current_platform.ray_device_key: 1}, max_calls=1)
def all_gather_test_worker(tp_size: int, pp_size: int, rank: int,
distributed_init_port: str):
# it is important to delete the ASCEND_RT_VISIBLE_DEVICES environment variable
# so that each worker can see all the GPUs
# they will be able to set the device to the correct GPU
# del os.environ["ASCEND_RT_VISIBLE_DEVICES"]
device = torch.device(f"{DEVICE_TYPE}:{rank}")
torch.npu.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port, backend="hccl")
num_dimensions = 3
tensor_size = list(range(2, num_dimensions + 2))
total_size = 1
for s in tensor_size:
total_size *= s
for all_gather_dimension in range(num_dimensions):
all_tensors = [
torch.arange(total_size, dtype=torch.float32,
device="npu").reshape(tensor_size) * (r + 1)
for r in range(tp_size)
]
expected = torch.cat(all_tensors, dim=all_gather_dimension)
t = all_tensors[rank % tp_size]
t = tensor_model_parallel_all_gather(t, all_gather_dimension)
torch.testing.assert_close(t, expected)
# @ray.remote(num_gpus=1, max_calls=1)
@ray.remote(num_gpus=0, resources={current_platform.ray_device_key: 1}, max_calls=1)
def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
distributed_init_port: str):
# it is important to delete the ASCEND_RT_VISIBLE_DEVICES environment variable
# so that each worker can see all the GPUs
# they will be able to set the device to the correct GPU
# del os.environ["ASCEND_RT_VISIBLE_DEVICES"]
device = torch.device(f"{DEVICE_TYPE}:{rank}")
torch.npu.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port, backend="hccl")
test_dict = {
# device tensor
"a": torch.arange(8, dtype=torch.float32, device="npu"),
# CPU tensor
"b": torch.arange(16, dtype=torch.int8, device="cpu"),
"c": "test",
"d": [1, 2, 3],
"e": {
"a": 1,
"b": 2
},
# empty tensor
"f": torch.tensor([], dtype=torch.float32, device="npu"),
}
if (rank % tp_size) == 0:
broadcast_tensor_dict(test_dict, src=0)
else:
recv_dict = broadcast_tensor_dict(src=0)
assert len(recv_dict) == len(test_dict)
torch.testing.assert_close(recv_dict["a"], test_dict["a"])
torch.testing.assert_close(recv_dict["b"], test_dict["b"])
assert recv_dict["c"] == test_dict["c"]
assert recv_dict["d"] == test_dict["d"]
assert recv_dict["e"] == test_dict["e"]
torch.testing.assert_close(recv_dict["f"], test_dict["f"])
# @ray.remote(num_gpus=1, max_calls=1)
@ray.remote(num_gpus=0, resources={current_platform.ray_device_key: 1}, max_calls=1)
def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
distributed_init_port: str):
# del os.environ["ASCEND_RT_VISIBLE_DEVICES"]
device = torch.device(f"{DEVICE_TYPE}:{rank}")
torch.npu.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port, backend="hccl")
test_dict = {
# device tensor
"a": torch.arange(8, dtype=torch.float32, device="npu"),
# CPU tensor
"b": torch.arange(16, dtype=torch.int8, device="cpu"),
"c": "test",
"d": [1, 2, 3],
"e": {
"a": 1,
"b": 2
},
# empty tensor
"f": torch.tensor([], dtype=torch.float32, device="npu"),
}
if not get_pp_group().is_first_rank:
recv_dict = get_pp_group().recv_tensor_dict()
if not get_pp_group().is_last_rank:
get_pp_group().send_tensor_dict(test_dict)
if not get_pp_group().is_first_rank:
assert len(recv_dict) == len(test_dict)
torch.testing.assert_close(recv_dict["a"], test_dict["a"])
torch.testing.assert_close(recv_dict["b"], test_dict["b"])
assert recv_dict["c"] == test_dict["c"]
assert recv_dict["d"] == test_dict["d"]
assert recv_dict["e"] == test_dict["e"]
torch.testing.assert_close(recv_dict["f"], test_dict["f"])
# @ray.remote(num_gpus=1, max_calls=1)
@ray.remote(num_gpus=0, resources={current_platform.ray_device_key: 1}, max_calls=1)
def send_recv_test_worker(tp_size: int, pp_size: int, rank: int,
distributed_init_port: str):
# del os.environ["ASCEND_RT_VISIBLE_DEVICES"]
device = torch.device(f"{DEVICE_TYPE}:{rank}")
torch.npu.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port, backend="hccl")
size = torch.tensor([64]).to("npu")
test_tensor = torch.arange(64, dtype=torch.float32, device="npu")
if not get_pp_group().is_first_rank:
recv_tensor = get_pp_group().recv(size, dtype=torch.float32)
if not get_pp_group().is_last_rank:
get_pp_group().send(test_tensor)
if not get_pp_group().is_first_rank:
torch.testing.assert_close(test_tensor, recv_tensor)
@pytest.mark.skipif(torch.npu.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("test_target", [
all_reduce_test_worker, all_gather_test_worker,
broadcast_tensor_dict_test_worker
])
def test_multi_process_tensor_parallel(tp_size, test_target):
multi_process_parallel(tp_size, 1, test_target)
@pytest.mark.skipif(torch.npu.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("pp_size", [2])
@pytest.mark.parametrize(
"test_target", [send_recv_test_worker,
# send_recv_tensor_dict_test_worker
])
def test_multi_process_pipeline_parallel(pp_size, test_target):
multi_process_parallel(1, pp_size, test_target)
@pytest.mark.skipif(torch.npu.device_count() < 4,
reason="Need at least 4 GPUs to run the test.")
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("pp_size", [2])
@pytest.mark.parametrize("test_target", [
send_recv_test_worker, send_recv_tensor_dict_test_worker,
all_reduce_test_worker, all_gather_test_worker,
broadcast_tensor_dict_test_worker
])
def test_multi_process_tensor_parallel_pipeline_parallel(
tp_size, pp_size, test_target):
multi_process_parallel(tp_size, pp_size, test_target) test resultDue to limitation of npu resource, there are some skipped cases (need 4 card). (vllm) cmq@cmq-docker:~/code/vllm-cpu/vllm$ pytest tests/distributed/test_comm_ops.py
============================================================================= test session starts ==============================================================================
platform linux -- Python 3.10.16, pytest-8.3.4, pluggy-1.5.0
rootdir: /home/cmq/code/vllm-cpu/vllm
configfile: pyproject.toml
plugins: anyio-4.8.0, typeguard-4.3.0
collected 9 items
tests/distributed/test_comm_ops.py ....sssss [100%]
=============================================================================== warnings summary ===============================================================================
../../../miniconda3/envs/vllm/lib/python3.10/site-packages/torch_npu/dynamo/torchair/__init__.py:3
/home/cmq/miniconda3/envs/vllm/lib/python3.10/site-packages/torch_npu/dynamo/torchair/__init__.py:3: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
import pkg_resources
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============================================================= 4 passed, 5 skipped, 1 warning in 129.40s (0:02:09) ============================================================== |
LGTM, thx! |
Yikun
approved these changes
Feb 11, 2025
wangxiyuan
added a commit
to wangxiyuan/vllm-ascend
that referenced
this pull request
Feb 11, 2025
Some PR for plugin support is not merged by vllm yet. This PR add monkey patch to vllm-ascend to make vllm-ascend work with vllm directly. This patch code should be removed once the related function is supported by vllm originally. Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
wangxiyuan
added a commit
that referenced
this pull request
Feb 11, 2025
Some PR for plugin support is not merged by vllm yet. This PR add monkey patch to vllm-ascend to make vllm-ascend work with vllm directly. This patch code should be removed once the related function is supported by vllm originally. cherry pick to 0.7.1 Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
wangxiyuan
pushed a commit
that referenced
this pull request
Feb 17, 2025
### What this PR does / why we need it? Revert communicator patch as vllm-project/vllm#13208 has been merged. ### Does this PR introduce _any_ user-facing change? N/A ### How was this patch tested? test locally by #30 (comment) Signed-off-by: MengqingCao <cmq0113@163.com>
Angazenn
pushed a commit
to Angazenn/vllm-ascend
that referenced
this pull request
Feb 21, 2025
Some PR for plugin support is not merged by vllm yet. This PR add monkey patch to vllm-ascend to make vllm-ascend work with vllm directly. This patch code should be removed once the related function is supported by vllm originally. cherry pick to 0.7.1 Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com> Signed-off-by: angazenn <zengyanjia@huawei.com>
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Some PR for plugin support is not merged by vllm yet. This PR add monkey patch to vllm-ascend to make vllm-ascend work with vllm directly.
This patch code should be removed once the related function is supported by vllm originally.