diff --git a/examples/offline_distributed_inference_npu.py b/examples/offline_distributed_inference_npu.py
index 88533786..937bac1f 100644
--- a/examples/offline_distributed_inference_npu.py
+++ b/examples/offline_distributed_inference_npu.py
@@ -32,7 +32,7 @@
 llm = LLM(
     model="Qwen/Qwen2.5-0.5B-Instruct",
     tensor_parallel_size=2,
-    distributed_executor_backend="mp",
+    distributed_executor_backend="ray",
     trust_remote_code=True,
 )
 
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index 1f22e564..2af9d7ea 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -90,6 +90,8 @@ def mem_get_info(cls) -> Tuple[int, int]:
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         # Register ops when setup.
         from vllm_ascend import ops  # noqa: F401
+        # Monkey patch RayWorkerWrapper when setup.
+        from vllm_ascend import ray_patch  # noqa: F401
 
         parallel_config = vllm_config.parallel_config
         if parallel_config.worker_cls == "auto":
diff --git a/vllm_ascend/ray_patch.py b/vllm_ascend/ray_patch.py
new file mode 100644
index 00000000..4cbd778f
--- /dev/null
+++ b/vllm_ascend/ray_patch.py
@@ -0,0 +1,10 @@
+import vllm
+from vllm.executor.ray_utils import RayWorkerWrapper
+import torch_npu  # noqa: F401
+
+class NPURayWorkerWrapper(RayWorkerWrapper):
+    """Import torch_npu in Ray worker processes via an empty subclass and a monkey patch.
+    """
+    pass
+
+vllm.executor.ray_utils.RayWorkerWrapper = NPURayWorkerWrapper
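
Why the empty subclass appears to work: Ray reconstructs an actor class inside each worker process, and for a class defined in an importable module that reconstruction imports the class's defining module there. Since `vllm_ascend/ray_patch.py` executes `import torch_npu` at module level, every Ray worker that materializes `NPURayWorkerWrapper` imports torch_npu as a side effect. Below is a minimal, self-contained sketch of the rebinding mechanics only; the `ray_utils` stand-in module and the `print` call (in place of `import torch_npu`) are illustrative, not part of the change above.

```python
import sys
import types

# Stand-in for vllm.executor.ray_utils: a module exposing the class to patch.
ray_utils = types.ModuleType("ray_utils")
sys.modules["ray_utils"] = ray_utils


class RayWorkerWrapper:
    """Stand-in for vllm's RayWorkerWrapper."""


ray_utils.RayWorkerWrapper = RayWorkerWrapper

# --- contents of the patch module ---
# Module-level side effect; in the real patch this is `import torch_npu`,
# which runs in whichever process imports the module.
print("side-effect import executed in this process")


class NPURayWorkerWrapper(ray_utils.RayWorkerWrapper):
    """Empty subclass; its only job is to live in a module whose import
    triggers the side effect above."""


# Rebind the module attribute so later lookups resolve to the subclass.
ray_utils.RayWorkerWrapper = NPURayWorkerWrapper

# Anything resolving the name *after* the patch gets the subclass. When Ray
# pickles such a class for an actor, unpickling it in a worker imports the
# defining module, re-running the side-effect import in that process.
from ray_utils import RayWorkerWrapper as Resolved

assert Resolved is NPURayWorkerWrapper
```

Note the ordering constraint this pattern carries: rebinding the module attribute only affects code that resolves `RayWorkerWrapper` after the patch runs, and any module that already executed `from vllm.executor.ray_utils import RayWorkerWrapper` keeps the original class. Importing `ray_patch` inside `check_and_update_config`, which runs during platform setup, is presumably what keeps the patch ahead of executor construction.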