From 4888910fe749ab9b0e0c39e8c7f90435c26f0af8 Mon Sep 17 00:00:00 2001
From: wangxiyuan
Date: Tue, 11 Feb 2025 14:48:41 +0800
Subject: [PATCH] Make code more clean

---
 docs/cn/installation.zh.md  |  38 ++++++++++++
 vllm_ascend/attention.py    |   6 +-
 vllm_ascend/model_runner.py |   7 ++-
 vllm_ascend/platform.py     |  34 ++++-------
 vllm_ascend/utils.py        |   2 +
 vllm_ascend/worker.py       | 119 ++++++++++++++++--------------------
 6 files changed, 112 insertions(+), 94 deletions(-)
 create mode 100644 docs/cn/installation.zh.md

diff --git a/docs/cn/installation.zh.md b/docs/cn/installation.zh.md
new file mode 100644
index 00000000..ceddf608
--- /dev/null
+++ b/docs/cn/installation.zh.md
@@ -0,0 +1,38 @@
+### Ascend NPU environment preparation
+
+### Dependencies
+| Requirement | Supported versions | Recommended version | Note |
+|-------------|--------------------|---------------------|------|
+| vLLM        | main               | main                | Required by vllm-ascend |
+| Python      | >= 3.9             | [3.10](https://www.python.org/downloads/) | Required by vllm |
+| CANN        | >= 8.0.RC2         | [8.0.RC3](https://www.hiascend.com/developer/download/community/result?module=cann&cann=8.0.0.beta1) | Required by vllm-ascend and torch-npu |
+| torch-npu   | >= 2.4.0           | [2.5.1rc1](https://gitee.com/ascend/pytorch/releases/tag/v6.0.0.alpha001-pytorch2.5.1) | Required by vllm-ascend |
+| torch       | >= 2.4.0           | [2.5.1](https://github.com/pytorch/pytorch/releases/tag/v2.5.1) | Required by torch-npu and vllm |
+
+
+Here is a short guide to installing the recommended versions:
+
+#### Installation using a container
+
+You can use the [container image](https://hub.docker.com/r/ascendai/cann) directly with a single command:
+
+```bash
+docker run \
+    --name vllm-ascend-env \
+    --device /dev/davinci1 \
+    --device /dev/davinci_manager \
+    --device /dev/devmm_svm \
+    --device /dev/hisi_hdc \
+    -v /usr/local/dcmi:/usr/local/dcmi \
+    -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
+    -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \
+    -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \
+    -v /etc/ascend_install.info:/etc/ascend_install.info \
+    -it quay.io/ascend/cann:8.0.rc3.beta1-910b-ubuntu22.04-py3.10 bash
+```
+
+You do not need to install `torch` and `torch_npu` manually; they will be installed automatically as dependencies of `vllm-ascend`.
+
+#### Manual installation
+
+Alternatively, you can set up the environment manually by following the [Ascend installation guide](https://ascend.github.io/docs/sources/ascend/quick_install.html).
diff --git a/vllm_ascend/attention.py b/vllm_ascend/attention.py
index 2f9b5e70..7e7a33ad 100644
--- a/vllm_ascend/attention.py
+++ b/vllm_ascend/attention.py
@@ -22,11 +22,7 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
 
 import torch
-
-try:
-    import torch_npu  # noqa: F401
-except ImportError:
-    print("Failed to import torch_npu.")
+import torch_npu
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
diff --git a/vllm_ascend/model_runner.py b/vllm_ascend/model_runner.py
index 77e093b5..e7056da2 100644
--- a/vllm_ascend/model_runner.py
+++ b/vllm_ascend/model_runner.py
@@ -47,7 +47,6 @@
 from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                              MultiModalKwargs, MultiModalPlaceholderMap,
                              MultiModalRegistry)
-from vllm.platforms import current_platform
 from vllm.prompt_adapter.layers import PromptAdapterMapping
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
@@ -61,6 +60,8 @@
     _init_attn_metadata_from_tensor_dict,
     _init_sampling_metadata_from_tensor_dict)
 
+from vllm_ascend.platform import NPUPlatform
+
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionBackend
 
@@ -1303,7 +1304,7 @@ def need_send_kv(self, model_input, kv_caches) -> bool:
         return self.vllm_config.kv_transfer_config.is_kv_producer and (
             not is_profile_run) and is_prefill_run
 
-    @current_platform.inference_mode()
+    @NPUPlatform.inference_mode()
     def profile_run(self) -> None:
         # Enable top-k sampling to reflect the accurate memory usage.
         sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
@@ -1380,7 +1381,7 @@ def profile_run(self) -> None:
             dtype=self.model_config.dtype,
             device=self.device)
         self.execute_model(model_input, kv_caches, intermediate_tensors)
-        current_platform.synchronize()
+        NPUPlatform.synchronize()
         return
 
     def get_model(self) -> nn.Module:
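The model_runner.py hunks above route `inference_mode()` and `synchronize()` through the Ascend-specific `NPUPlatform` instead of vLLM's generic `current_platform`. A minimal sketch of the same pattern in isolation; the `run_profile_pass` helper and its arguments are hypothetical and not part of this patch:

```python
# Illustrative sketch only: NPUPlatform.inference_mode() disables autograd
# tracking for the decorated call, and synchronize() waits for all queued
# NPU kernels before control returns (e.g. before memory is measured).
from vllm_ascend.platform import NPUPlatform


@NPUPlatform.inference_mode()
def run_profile_pass(model, dummy_inputs):  # hypothetical helper
    outputs = model(dummy_inputs)
    NPUPlatform.synchronize()
    return outputs
```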
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index 2b847de1..77a2a624 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -19,29 +19,14 @@
 from typing import Optional, Tuple
 
 import torch
-
-try:
-    import torch_npu  # noqa: F401
-except ImportError:
-    print("Failed to import torch_npu.")
+import torch_npu  # noqa: F401
 
 from vllm.config import VllmConfig
 from vllm.platforms import Platform, PlatformEnum
 
-os.environ["RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES"] = "1"
-
+from vllm_ascend.utils import ASCEND_RT_VISIBLE_DEVICES
 
-def _device_id_to_physical_device_id(device_id: int) -> int:
-    if "ASCEND_RT_VISIBLE_DEVICES" in os.environ:
-        device_ids = os.environ["ASCEND_RT_VISIBLE_DEVICES"].split(",")
-        if device_ids == [""]:
-            raise RuntimeError("ASCEND_RT_VISIBLE_DEVICES is set to empty"
-                               "string, which means Ascend NPU support is"
-                               "disabled.")
-        physical_device_id = device_ids[device_id]
-        return int(physical_device_id)
-    else:
-        return device_id
+os.environ["RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES"] = "1"
 
 
 class NPUPlatform(Platform):
@@ -51,7 +36,7 @@ class NPUPlatform(Platform):
     device_type: str = "npu"
     simple_compile_backend: str = "npu"
     ray_device_key: str = "NPU"
-    device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES"
+    device_control_env_var: str = ASCEND_RT_VISIBLE_DEVICES
 
     @classmethod
     def get_device_capability(cls, device_id: int = 0):
@@ -59,8 +44,15 @@ def get_device_capability(cls, device_id: int = 0):
 
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
-        physical_device_id = _device_id_to_physical_device_id(device_id)
-        return torch.npu.get_device_name(physical_device_id)
+        if ASCEND_RT_VISIBLE_DEVICES in os.environ:
+            device_ids = os.environ[ASCEND_RT_VISIBLE_DEVICES].split(",")
+            if device_ids == [""]:
+                raise RuntimeError("ASCEND_RT_VISIBLE_DEVICES is set to an "
+                                   "empty string, which means Ascend NPU "
+                                   "support is disabled.")
+            physical_device_id = device_ids[device_id]
+            device_id = int(physical_device_id)
+        return torch.npu.get_device_name(device_id)
 
     @classmethod
     def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index d12e72c0..3a6a3563 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -21,6 +21,8 @@
 
 logger = init_logger(__name__)
 
+ASCEND_RT_VISIBLE_DEVICES = "ASCEND_RT_VISIBLE_DEVICES"
+
 
 def try_register_lib(lib_name: str, lib_info: str = ""):
     import importlib
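The new `get_device_name()` above resolves a logical device index against `ASCEND_RT_VISIBLE_DEVICES` before asking torch-npu for the device name. The same mapping, restated as a standalone function so it can be checked without NPU hardware; the helper name below is ours, not part of the patch:

```python
# Standalone restatement of the logical -> physical NPU id mapping driven by
# ASCEND_RT_VISIBLE_DEVICES; runnable without an Ascend device.
import os


def resolve_physical_device_id(device_id: int) -> int:  # hypothetical helper
    visible = os.environ.get("ASCEND_RT_VISIBLE_DEVICES")
    if visible is None:
        return device_id  # no remapping configured
    device_ids = visible.split(",")
    if device_ids == [""]:
        raise RuntimeError("ASCEND_RT_VISIBLE_DEVICES is set to an empty "
                           "string, which means Ascend NPU support is "
                           "disabled.")
    return int(device_ids[device_id])


os.environ.pop("ASCEND_RT_VISIBLE_DEVICES", None)
assert resolve_physical_device_id(1) == 1       # unset: identity mapping
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "4,7"
assert resolve_physical_device_id(1) == 7       # logical 1 -> physical NPU 7
```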
diff --git a/vllm_ascend/worker.py b/vllm_ascend/worker.py
index c5884e36..e99eefd6 100644
--- a/vllm_ascend/worker.py
+++ b/vllm_ascend/worker.py
@@ -34,7 +34,6 @@
 from vllm.model_executor import set_random_seed
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
-from vllm.platforms import current_platform
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
                            SequenceGroupMetadata, SequenceGroupMetadataDelta)
@@ -47,6 +46,7 @@
                           WorkerInput)
 
 from vllm_ascend.model_runner import NPUModelRunner
+from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.utils import try_register_lib
 
 logger = init_logger(__name__)
@@ -68,13 +68,14 @@ def __init__(
         is_driver_worker: bool = False,
         model_runner_cls: Optional[Type[ModelRunnerBase]] = None,
     ) -> None:
-        WorkerBase.__init__(self, vllm_config=vllm_config)
+
         # Try to import mindie_turbo to accelerate vLLM inference.
         try_register_lib(
             "mindie_turbo",
             "MindIE Turbo is installed. vLLM inference will be accelerated with MindIE Turbo."
         )
+
         # distribute related config
         self.parallel_config.rank = rank
         self.local_rank = local_rank
@@ -101,19 +102,20 @@
             not in ["medusa", "mlp_speculator", "eagle"]) \
             else {"return_hidden_states": True}
 
-        ModelRunnerClass: Type[ModelRunnerBase] = NPUModelRunner
-        if model_config.runner_type == "pooling":
-            ModelRunnerClass = PoolingModelRunner
-        elif self.model_config.is_encoder_decoder:
-            ModelRunnerClass = EncoderDecoderModelRunner
-        self.model_runner: ModelRunnerBase = ModelRunnerClass(
-            vllm_config=self.vllm_config,
-            kv_cache_dtype=self.cache_config.cache_dtype,
-            is_driver_worker=is_driver_worker,
-            **speculative_args,
-        )
         if model_runner_cls is not None:
             self.model_runner = model_runner_cls(self.model_runner)
+        else:
+            ModelRunnerClass: Type[ModelRunnerBase] = NPUModelRunner
+            if model_config.runner_type == "pooling":
+                ModelRunnerClass = PoolingModelRunner
+            elif self.model_config.is_encoder_decoder:
+                ModelRunnerClass = EncoderDecoderModelRunner
+            self.model_runner: ModelRunnerBase = ModelRunnerClass(
+                vllm_config=self.vllm_config,
+                kv_cache_dtype=self.cache_config.cache_dtype,
+                is_driver_worker=is_driver_worker,
+                **speculative_args,
+            )
 
         # Uninitialized cache engine. Will be initialized by
         # initialize_cache.
@@ -170,17 +172,21 @@ def init_device(self) -> None:
             # # This env var set by Ray causes exceptions with graph building.
             # os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
             self.device = torch.device(f"npu:{self.local_rank}")
-            current_platform.set_device(self.device)
-
-            current_platform.empty_cache()
-            self.init_npu_memory = current_platform.mem_get_info()[0]
+            NPUPlatform.set_device(self.device)
+            NPUPlatform.empty_cache()
+            self.init_npu_memory = NPUPlatform.mem_get_info()[0]
         else:
            raise RuntimeError(
                 f"Not support device type: {self.device_config.device}")
         # Initialize the distributed environment.
-        init_worker_distributed_environment(self.parallel_config, self.rank,
-                                            self.distributed_init_method,
-                                            self.local_rank)
+        set_custom_all_reduce(
+            not self.parallel_config.disable_custom_all_reduce)
+        init_distributed_environment(
+            self.parallel_config.world_size, self.rank,
+            self.distributed_init_method, self.local_rank, "hccl")
+        ensure_model_parallel_initialized(
+            self.parallel_config.tensor_parallel_size,
+            self.parallel_config.pipeline_parallel_size)
         # Set random seed.
         set_random_seed(self.model_config.seed)
@@ -206,7 +212,7 @@ def save_tensorized_model(
         self.model_runner.save_tensorized_model(
             tensorizer_config=tensorizer_config, )
 
-    @current_platform.inference_mode()
+    @NPUPlatform.inference_mode()
     def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Profiles the peak memory usage of the model to determine how many
         KV blocks may be allocated without OOMs.
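The `init_device()` hunk above inlines the distributed bring-up that previously lived in `init_worker_distributed_environment()` (removed at the end of this patch). A hypothetical single-process example of the same call sequence; the TCP address is a placeholder, and a torch-npu build that provides the `hccl` backend is assumed:

```python
# Single-process sketch mirroring the inlined calls in init_device():
# toggle the custom all-reduce path, initialize the process group on the
# "hccl" backend, then set up tensor/pipeline parallel groups.
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment,
                              set_custom_all_reduce)

set_custom_all_reduce(True)  # custom all-reduce not disabled here
init_distributed_environment(1, 0, "tcp://127.0.0.1:29500", 0, "hccl")
ensure_model_parallel_initialized(1, 1)  # TP=1, PP=1
```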
@@ -219,7 +225,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         """
         # Profile the memory usage of the model and get the maximum number of
         # cache blocks that can be allocated with the remaining free memory.
-        current_platform.empty_cache()
+        NPUPlatform.empty_cache()
 
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
@@ -227,7 +233,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
 
         # Calculate the number of blocks that can be allocated with the
         # profiled peak memory.
-        free_npu_memory, total_npu_memory = current_platform.mem_get_info()
+        free_npu_memory, total_npu_memory = NPUPlatform.mem_get_info()
         # NOTE(woosuk): Here we assume that the other processes using the same
         # GPU did not change their memory usage during the profiling.
         peak_memory = self.init_npu_memory - free_npu_memory
@@ -248,17 +254,36 @@
         gc.collect()
         # TODO: don`t need impl this func after empty_cache in
         # Worker.determine_num_available_blocks() unified`
-        current_platform.empty_cache()
-        return num_npu_blocks, num_cpu_blocks
+        NPUPlatform.empty_cache()
+        return num_npu_blocks, num_cpu_blocks
+
+    def _raise_if_cache_size_invalid(self, num_gpu_blocks, block_size,
+                                     is_attention_free, max_model_len) -> None:
+        if is_attention_free and num_gpu_blocks != 0:
+            raise ValueError("No memory should be allocated for the cache "
+                             f"blocks for an attention-free model, but "
+                             f"{num_gpu_blocks} blocks are allocated.")
+        if not is_attention_free and num_gpu_blocks <= 0:
+            raise ValueError("No available memory for the cache blocks. "
+                             "Try increasing `gpu_memory_utilization` when "
+                             "initializing the engine.")
+        max_seq_len = block_size * num_gpu_blocks
+        if not is_attention_free and max_model_len > max_seq_len:
+            raise ValueError(
+                f"The model's max seq len ({max_model_len}) "
+                "is larger than the maximum number of tokens that can be "
+                f"stored in KV cache ({max_seq_len}). Try increasing "
+                "`gpu_memory_utilization` or decreasing `max_model_len` when "
+                "initializing the engine.")
 
     def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         """Allocate NPU and CPU KV cache with the specified number of blocks.
         """
-        raise_if_cache_size_invalid(num_gpu_blocks,
-                                    self.cache_config.block_size,
-                                    self.cache_config.is_attention_free,
-                                    self.model_config.max_model_len)
+        self._raise_if_cache_size_invalid(num_gpu_blocks,
+                                          self.cache_config.block_size,
+                                          self.cache_config.is_attention_free,
+                                          self.model_config.max_model_len)
 
         self.cache_config.num_gpu_blocks = num_gpu_blocks
         self.cache_config.num_cpu_blocks = num_cpu_blocks
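A quick worked example of the `max_seq_len` guard in the new `_raise_if_cache_size_invalid()` method, with made-up numbers:

```python
# Made-up numbers illustrating the check: the KV cache must be able to hold
# at least max_model_len tokens, otherwise the worker raises ValueError.
block_size = 16          # tokens per KV-cache block
num_gpu_blocks = 512     # blocks that fit in the profiled NPU memory
max_model_len = 8192     # requested context length

max_seq_len = block_size * num_gpu_blocks  # 8192 tokens fit in the cache
assert max_model_len <= max_seq_len        # passes; 8193+ would be rejected
```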
""" - raise_if_cache_size_invalid(num_gpu_blocks, - self.cache_config.block_size, - self.cache_config.is_attention_free, - self.model_config.max_model_len) + self._raise_if_cache_size_invalid(num_gpu_blocks, + self.cache_config.block_size, + self.cache_config.is_attention_free, + self.model_config.max_model_len) self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks @@ -447,39 +472,3 @@ def get_cache_block_size_bytes(self) -> int: return CacheEngine.get_cache_block_size(self.cache_config, self.model_config, self.parallel_config) - - -def init_worker_distributed_environment( - parallel_config: ParallelConfig, - rank: int, - distributed_init_method: Optional[str] = None, - local_rank: int = -1, - backend: str = "hccl") -> None: - """Initialize the distributed environment.""" - set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) - - init_distributed_environment(parallel_config.world_size, rank, - distributed_init_method, local_rank, backend) - - ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size) - - -def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free, - max_model_len) -> None: - if is_attention_free and num_gpu_blocks != 0: - raise ValueError("No memory should be allocated for the cache blocks " - f"for an attention-free model, but {num_gpu_blocks}" - "blocks are allocated.") - if not is_attention_free and num_gpu_blocks <= 0: - raise ValueError("No available memory for the cache blocks. " - "Try increasing `gpu_memory_utilization` when " - "initializing the engine.") - max_seq_len = block_size * num_gpu_blocks - if not is_attention_free and max_model_len > max_seq_len: - raise ValueError( - f"The model's max seq len ({max_model_len}) " - "is larger than the maximum number of tokens that can be " - f"stored in KV cache ({max_seq_len}). Try increasing " - "`gpu_memory_utilization` or decreasing `max_model_len` when " - "initializing the engine.")