Skip to content

Commit

Permalink
[Build] Make CI work on 0.7.3-dev (#140) (#172)
Browse files Browse the repository at this point in the history
### What this PR does / why we need it?
This PR resolves the issue with inference on the Ray backend. For more
details, see
[here](#92).

### Does this PR introduce _any_ user-facing change?
no.

### How was this patch tested?
Validation was performed based on v0.7.3, and the specific validation
script can be found
[here](#92).

---------

Signed-off-by: Chenguang Li <757486878@qq.com>
  • Loading branch information
noemotiovon authored Feb 27, 2025
1 parent b074047 commit 8a62c1f
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 1 deletion.
2 changes: 1 addition & 1 deletion examples/offline_distributed_inference_npu.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
llm = LLM(
model="Qwen/Qwen2.5-0.5B-Instruct",
tensor_parallel_size=2,
distributed_executor_backend="mp",
distributed_executor_backend="ray",
trust_remote_code=True,
)

Expand Down
22 changes: 22 additions & 0 deletions vllm_ascend/patch/ray_patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import torch_npu # noqa: F401
import vllm
from vllm.executor.ray_utils import RayWorkerWrapper

if RayWorkerWrapper is not None:

class NPURayWorkerWrapper(RayWorkerWrapper):
"""Importing torch_npu in other Ray processes through an empty class and
a monkey patch.
When Ray performs a remote call, it serializes the Task or Actor and passes
it to the Worker process, where it is deserialized and executed.
If no patch is applied, the default code of the RayWorkerWrapper provided
by vLLM is used, which does not import torch_npu, causing an error in the
Worker process.
See https://github.com/vllm-project/vllm-ascend/pull/92.
"""

pass

vllm.executor.ray_utils.RayWorkerWrapper = NPURayWorkerWrapper
3 changes: 3 additions & 0 deletions vllm_ascend/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ def mem_get_info(cls) -> Tuple[int, int]:

@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
# RayWorkerWrapper monkey patch when setup
from vllm_ascend.patch import ray_patch # noqa: F401

parallel_config = vllm_config.parallel_config
if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = "vllm_ascend.worker.NPUWorker"
Expand Down

0 comments on commit 8a62c1f

Please sign in to comment.