diff --git a/examples/offline_distributed_inference_npu.py b/examples/offline_distributed_inference_npu.py
index 88533786..937bac1f 100644
--- a/examples/offline_distributed_inference_npu.py
+++ b/examples/offline_distributed_inference_npu.py
@@ -32,7 +32,7 @@
 
 llm = LLM(
     model="Qwen/Qwen2.5-0.5B-Instruct",
     tensor_parallel_size=2,
-    distributed_executor_backend="mp",
+    distributed_executor_backend="ray",
     trust_remote_code=True,
 )
diff --git a/vllm_ascend/patch/ray_patch.py b/vllm_ascend/patch/ray_patch.py
new file mode 100644
index 00000000..8fce6c72
--- /dev/null
+++ b/vllm_ascend/patch/ray_patch.py
@@ -0,0 +1,20 @@
+import torch_npu  # noqa: F401
+import vllm
+from vllm.executor.ray_utils import RayWorkerWrapper
+
+if RayWorkerWrapper is not None:
+
+    class NPURayWorkerWrapper(RayWorkerWrapper):
+        """Importing torch_npu in other Ray processes through an empty class and
+        a monkey patch.
+
+        When Ray performs a remote call, it serializes the Task or Actor and passes
+        it to the Worker process, where it is deserialized and executed.
+
+        If no patch is applied, the default code of the RayWorkerWrapper provided
+        by vLLM is used, which does not import torch_npu, causing an error in the
+        Worker process.
+        See https://github.com/vllm-project/vllm-ascend/pull/92.
+        """
+
+    vllm.executor.ray_utils.RayWorkerWrapper = NPURayWorkerWrapper
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index 45647b02..a4a915ff 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -103,6 +103,9 @@ def mem_get_info(cls) -> Tuple[int, int]:
 
     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
+        # Apply the RayWorkerWrapper monkey patch at setup time.
+        from vllm_ascend.patch import ray_patch  # noqa: F401
+
         parallel_config = vllm_config.parallel_config
         if parallel_config.worker_cls == "auto":
             parallel_config.worker_cls = "vllm_ascend.worker.NPUWorker"