From d06be4560e4a3fcf499f930ec56f6ba7aca76935 Mon Sep 17 00:00:00 2001 From: wangxiyuan Date: Fri, 21 Feb 2025 10:00:54 +0800 Subject: [PATCH] [Core] Cherry pick from 0.7.1 to keep the main code newest Signed-off-by: wangxiyuan --- .github/workflows/vllm_ascend_test.yaml | 2 +- vllm_ascend/attention.py | 833 ++++++++++++++--------- vllm_ascend/model_runner.py | 63 +- vllm_ascend/ops/__init__.py | 4 +- vllm_ascend/ops/activation.py | 29 + vllm_ascend/ops/fused_moe.py | 176 +++++ vllm_ascend/ops/rotary_embedding.py | 56 ++ vllm_ascend/platform.py | 21 +- vllm_ascend/quantization/__init__.py | 0 vllm_ascend/quantization/quant_config.py | 256 +++++++ vllm_ascend/quantization/quantizer.py | 51 ++ 11 files changed, 1137 insertions(+), 354 deletions(-) create mode 100644 vllm_ascend/ops/activation.py create mode 100644 vllm_ascend/ops/fused_moe.py create mode 100644 vllm_ascend/ops/rotary_embedding.py create mode 100644 vllm_ascend/quantization/__init__.py create mode 100644 vllm_ascend/quantization/quant_config.py create mode 100644 vllm_ascend/quantization/quantizer.py diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml index b39a9310..de120be0 100644 --- a/.github/workflows/vllm_ascend_test.yaml +++ b/.github/workflows/vllm_ascend_test.yaml @@ -102,7 +102,7 @@ jobs: run: | pip install -e . - - name: Install torch-npu + - name: Install pta run: | mkdir pta cd pta diff --git a/vllm_ascend/attention.py b/vllm_ascend/attention.py index 2f9b5e70..3088efb5 100644 --- a/vllm_ascend/attention.py +++ b/vllm_ascend/attention.py @@ -1,8 +1,6 @@ # # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. # This file is a part of the vllm-ascend project. -# Adapted from vllm-project/vllm/vllm/attention/backends -# Copyright 2023 The vLLM team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,10 +15,10 @@ # limitations under the License. # -import math from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type +import numpy as np import torch try: @@ -30,19 +28,72 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, - AttentionMetadata, AttentionType) -from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState, + AttentionMetadata, AttentionType, + MLAAttentionImpl) +from vllm.attention.backends.utils import (CommonAttentionState, CommonMetadataBuilder, + compute_slot_mapping, compute_slot_mapping_start_idx, is_block_tables_empty) -from vllm.attention.ops.paged_attn import (PagedAttention, - PagedAttentionMetadata) +from vllm.utils import async_tensor_h2d, make_tensor_with_pad if TYPE_CHECKING: from vllm_ascend.model_runner import ModelInputForNPUBuilder -SHARE_MASK_TRIL_PREFIX_CACHE = None -SHARE_MASK_TRIL = None + +def generate_attn_mask(max_seq_len: int, dtype=torch.float16): + # Construct lower triangle matrix. + mask_flag = torch.tril( + torch.ones((max_seq_len, max_seq_len), + dtype=torch.bool)).view(max_seq_len, max_seq_len) + # Create upper triangle matrix used to mark mask positions. + mask_flag = ~mask_flag + # Currently for fp16 dtype, the mask value should be set to -inf. + # TODO: Eliminate this part in the future. 
+ if dtype == torch.float16: + mask_value = torch.finfo(torch.float32).min + else: + mask_value = 1 + attn_mask = torch.masked_fill(torch.zeros(size=(max_seq_len, max_seq_len)), + mask_flag, mask_value).to(dtype) + return attn_mask + + +class AttentionMaskBuilder: + + def __init__(self, attn_mask: torch.Tensor): + self._seq_len_cached = attn_mask.shape[0] + self.attn_mask_cache = attn_mask + + @classmethod + def initialize_from_len(cls, + max_seq_len: int, + dtype: torch.dtype = torch.float16): + return cls(generate_attn_mask(max_seq_len, dtype)) + + def update_attn_cache(self, seqlen: int, dtype: torch.dtype, + device: torch.device): + if seqlen > self._seq_len_cached or self.attn_mask_cache.dtype != dtype: + self._seq_len_cached = seqlen + self.attn_mask_cache = generate_attn_mask(seqlen, dtype) + if self.attn_mask_cache.device != device: + self.attn_mask_cache = self.attn_mask_cache.to(device) + + def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype, + device: torch.device): + self.update_attn_cache(max_seq_len, dtype, device) + return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous() + + def get_decode_attn_mask( + self, + input_lengths: torch.tensor, + max_s: int, + dtype: torch.dtype, + device: torch.device, + ): + self.update_attn_cache(max_s, dtype, device) + return (self.attn_mask_cache.index_select( + 0, input_lengths)[:, :max_s].view(-1, 1, max_s).contiguous()) class AscendAttentionBackend(AttentionBackend): @@ -111,22 +162,24 @@ def make_metadata_builder(cls, *args, **kwargs) -> "AscendMetadataBuilder": return cls.get_builder_cls()(*args, **kwargs) -class AscendPagedAttention(PagedAttention): +class AscendMLAAttentionBackend(AscendAttentionBackend): @staticmethod - def write_to_paged_cache( - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - slot_indices: torch.Tensor, - ) -> None: - torch_npu.npu_scatter_nd_update_(key_cache, slot_indices, key) - torch_npu.npu_scatter_nd_update_(value_cache, slot_indices, value) + def get_impl_cls() -> Type["AscendMLAAttentionBackendImpl"]: + return AscendMLAAttentionBackendImpl + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (1, num_blocks, block_size, num_kv_heads * head_size) @dataclass -class AscendMetadata(AttentionMetadata, PagedAttentionMetadata): +class AscendMetadata(AttentionMetadata): """Metadata for Ascendbackend. * modified from XFormersbackend NOTE: Any python object stored here is not updated when it is @@ -142,9 +195,6 @@ class AscendMetadata(AttentionMetadata, PagedAttentionMetadata): # |-------------------- seq_len ----------------------| # |-- query_len ---| - # seq_lens stored as a tensor. - seq_lens_tensor: Optional[torch.Tensor] - # FIXME: It is for flash attn. # Maximum sequence length among prefill batch. 0 if there are decoding # requests only. @@ -153,33 +203,17 @@ class AscendMetadata(AttentionMetadata, PagedAttentionMetadata): # requests only. max_decode_seq_len: int - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - use_cuda_graph: bool + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + block_tables: Optional[torch.Tensor] # (batch_size,). The sequence length per sequence. Sequence length means # the computed tokens + new tokens None if it is a decoding. 
seq_lens: Optional[List[int]] = None - # FIXME: It is for flash attn. - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] = None - - # (batch_size,) A tensor of context lengths (tokens that are computed - # so far). - context_lens_tensor: Optional[torch.Tensor] = None - # Maximum query length in the batch. None for decoding. max_query_len: Optional[int] = None - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] = None - # Self-attention prefill/decode metadata cache _cached_prefill_metadata: Optional["AscendMetadata"] = None _cached_decode_metadata: Optional["AscendMetadata"] = None @@ -197,16 +231,12 @@ class AscendMetadata(AttentionMetadata, PagedAttentionMetadata): num_encoder_tokens: Optional[int] = None attn_mask: Optional[torch.Tensor] = None - pse_shift: Optional[torch.Tensor] = None - sparse_mode: int = 0 # Cross-attention memory-mapping data structures: slot mapping # and block tables cross_slot_mapping: Optional[torch.Tensor] = None cross_block_tables: Optional[torch.Tensor] = None - # slot_mapping: Optional[torch.Tensor] = None - @property def prefill_metadata(self) -> Optional["AscendMetadata"]: if self.num_prefills == 0: @@ -214,46 +244,31 @@ def prefill_metadata(self) -> Optional["AscendMetadata"]: if self._cached_prefill_metadata is not None: # Recover cached prefill-phase attention - # metadata structure + # metadata structure. return self._cached_prefill_metadata assert ((self.seq_lens is not None) or (self.encoder_seq_lens is not None)) - assert ((self.seq_lens_tensor is not None) - or (self.encoder_seq_lens_tensor is not None)) - # Compute some attn_metadata fields which default to None - query_start_loc = (None if self.query_start_loc is None else - self.query_start_loc[:self.num_prefills + 1]) + # Compute some attn_metadata fields which default to None. slot_mapping = (None if self.slot_mapping is None else self.slot_mapping[:self.num_prefill_tokens]) seq_lens = (None if self.seq_lens is None else self.seq_lens[:self.num_prefills]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[:self.num_prefills]) - seq_start_loc = (None if self.seq_start_loc is None else - self.seq_start_loc[:self.num_prefills + 1]) - context_lens_tensor = (None if self.context_lens_tensor is None else - self.context_lens_tensor[:self.num_prefills]) block_tables = (None if self.block_tables is None else self.block_tables[:self.num_prefills]) - # Construct & cache prefill-phase attention metadata structure + # Construct & cache prefill-phase attention metadata structure. self._cached_prefill_metadata = AscendMetadata( num_prefills=self.num_prefills, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=0, slot_mapping=slot_mapping, seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, - query_start_loc=query_start_loc, - seq_start_loc=seq_start_loc, - context_lens_tensor=context_lens_tensor, block_tables=block_tables, - use_cuda_graph=False, # Begin encoder & cross attn fields below... 
encoder_seq_lens=self.encoder_seq_lens, encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, @@ -272,39 +287,27 @@ def decode_metadata(self) -> Optional["AscendMetadata"]: if self._cached_decode_metadata is not None: # Recover cached decode-phase attention - # metadata structure + # metadata structure. return self._cached_decode_metadata - assert ((self.seq_lens_tensor is not None) - or (self.encoder_seq_lens_tensor is not None)) - # Compute some attn_metadata fields which default to None + # Compute some attn_metadata fields which default to None. slot_mapping = (None if self.slot_mapping is None else self.slot_mapping[self.num_prefill_tokens:]) - seq_lens_tensor = (None if self.seq_lens_tensor is None else - self.seq_lens_tensor[self.num_prefills:]) + seq_lens = (None if self.seq_lens is None else + self.seq_lens[self.num_prefills:]) block_tables = (None if self.block_tables is None else self.block_tables[self.num_prefills:]) - # Construct & cache decode-phase attention metadata structure + # Construct & cache decode-phase attention metadata structure. self._cached_decode_metadata = AscendMetadata( num_prefills=0, num_prefill_tokens=0, num_decode_tokens=self.num_decode_tokens, slot_mapping=slot_mapping, - seq_lens_tensor=seq_lens_tensor, + seq_lens=seq_lens, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, - # Batch may be composed of prefill|decodes, adjust query start - # indices to refer to the start of decodes. E.g. - # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. - query_start_loc=(self.query_start_loc[self.num_prefills:] - - self.query_start_loc[self.num_prefills]) - if self.query_start_loc is not None else None, - seq_start_loc=self.seq_start_loc[self.num_prefills:] - if self.seq_start_loc is not None else None, - context_lens_tensor=None, block_tables=block_tables, - use_cuda_graph=self.use_cuda_graph, # Begin encoder & cross attn fields below... encoder_seq_lens=self.encoder_seq_lens, encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, @@ -320,6 +323,7 @@ def decode_metadata(self) -> Optional["AscendMetadata"]: class AscendMetadataBuilder(CommonMetadataBuilder[AscendMetadata]): _metadata_cls = AscendMetadata + _attn_mask_builder = None # noqa def __init__(self, input_builder: "ModelInputForNPUBuilder"): self.input_builder = input_builder @@ -327,42 +331,10 @@ def __init__(self, input_builder: "ModelInputForNPUBuilder"): self.sliding_window = input_builder.sliding_window self.block_size = input_builder.block_size - def compute_npu_slot_indices(self, is_profile_run, slot_indices, seq_id, - seq_len, context_len, start_idx, block_size, - block_tables, max_query_len): - """ - compute slot indices - slot mapping in other backend of vllm stores slot indices, - which are indicates by `block_number * block_size + block_offset` - In Ascend backend, slot mapping stores [block_number, block_offset]. - To distinguish this, slot_indices is used in this func - """ - if is_profile_run: - # During memory profiling, the block tables are not - # initialized yet. In this case, we just use a dummy - # slot mapping. - # In embeddings, the block tables are {seq_id: None}. - slot_indices.extend([[PAD_SLOT_ID, 0]] * seq_len) - return - # Mask the [0, start_idx) tokens of the prompt with - # [PAD_SLOT_ID, 0], where start_idx is max(0, seq_len - - # sliding_window). For example, if the prompt len is 10, - # sliding window is 8, and block size is 4, the first two - # tokens are masked and the slot mapping will be - # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. 
- padding_mask_len = max(0, start_idx - context_len) - slot_indices.extend([[PAD_SLOT_ID, 0]] * padding_mask_len) - - range_start = max(start_idx, context_len) - range_end = seq_len - numel = range_end - range_start - block_table = block_tables[seq_id] - - for i in range(range_start, range_end): - block_number = block_table[i // block_size] - block_offset = i % block_size - slot_indices.append([block_number, block_offset]) - slot_indices.extend([[PAD_SLOT_ID, 0]] * (max_query_len - numel)) + self.attn_mask = None + if AscendMetadataBuilder._attn_mask_builder is None: + AscendMetadataBuilder._attn_mask_builder = AttentionMaskBuilder.initialize_from_len( + 128, self.input_builder.runner.model_config.dtype) def _add_seq_group( self, inter_data: "ModelInputForNPUBuilder.InterDataForSeqGroup", @@ -374,12 +346,6 @@ def _add_seq_group( """ is_prompt = inter_data.is_prompt block_tables = inter_data.block_tables - max_query_len = max( - max(data.query_lens) - for data in self.input_builder.inter_data_list) - - is_prompt = inter_data.is_prompt - block_tables = inter_data.block_tables for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, curr_sliding_window_block) in zip( @@ -427,12 +393,78 @@ def _add_seq_group( start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, context_len, self.sliding_window) + compute_slot_mapping( + is_profile_run, + self.slot_mapping, + seq_id, + seq_len, + context_len, + start_idx, + self.block_size, + inter_data.block_tables, + ) + + def build( + self, + seq_lens: List[int], + query_lens: List[int], + ): + """Build attention metadata with on-device tensors. - self.compute_npu_slot_indices(is_profile_run, self.slot_mapping, - seq_id, seq_len, context_len, - start_idx, self.block_size, - inter_data.block_tables, - max_query_len) + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. 
+ """ + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled) + + device = self.runner.device + + max_query_len = max(query_lens) + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + + if self.num_prefills > 0: + self.attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask( # type: ignore + max_prefill_seq_len, + self.input_builder.runner.model_config.dtype, + self.input_builder.runner.device) + else: + self.attn_mask = None + + block_tables = make_tensor_with_pad( + self.block_tables, + pad=0, + dtype=torch.int32, + device=device, + ) + assert max_query_len > 0, "query_lens: {}".format(query_lens) + + assert device is not None + + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.int32, + device, self.runner.pin_memory) + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + self.multimodal_placeholder_maps.items() + } + + return self._metadata_cls( # type: ignore + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=False, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=self.num_decode_tokens, + seq_lens=seq_lens, + max_query_len=max_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + block_tables=block_tables, + attn_mask=self.attn_mask, + ) class AscendAttentionBackendImpl(AttentionImpl): @@ -454,15 +486,19 @@ def __init__( self.head_size = head_size self.scale = float(scale) self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.hidden_size = self.num_heads * self.head_size self.kv_cache_dtype = kv_cache_dtype self.sliding_window = sliding_window if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + alibi_slopes = torch.tensor(alibi_slopes, + dtype=torch.float32, + device="npu") self.alibi_slopes = alibi_slopes self.attn_type = attn_type assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.seq_len_cpu_tensor = None def forward( self, @@ -470,7 +506,7 @@ def forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - kv_cache: List[torch.Tensor], + kv_cache: torch.Tensor, attn_metadata: AscendMetadata, attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, @@ -491,195 +527,380 @@ def forward( Returns: shape = [batch_size, seq_len * num_heads * head_size] """ - assert layer._k_scale == 1.0 and layer._v_scale == 1.0 + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 attn_type = self.attn_type if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " "encoder/decoder cross-attention " "are not implemented for " "PallasAttentionBackendImpl") - # view q k v to BSH + # View q k v to BSH. num_tokens = query.shape[0] + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + # TODO: Remove this contiguous in the future. 
+ value = value.contiguous() + + output = torch.empty(num_tokens, + self.num_heads, + self.head_size, + dtype=query.dtype, + device=query.device) + + if hasattr(layer, 'quant_method'): + isPrefill = True if attn_metadata.num_prefills > 0 else False + if isPrefill: + assert attn_metadata.prefill_metadata is not None + self.seq_lens_tensor_cpu = torch.from_numpy( + np.array(attn_metadata.prefill_metadata.seq_lens).astype( + np.int32)) + else: + assert attn_metadata.decode_metadata is not None + self.seq_lens_tensor_cpu = torch.from_numpy( + np.array(attn_metadata.decode_metadata.seq_lens).astype( + np.int32)) + block_tables = attn_metadata.decode_metadata.block_tables if attn_metadata.decode_metadata else None + # Details of kv_cache arrangement in attention quantization + # are implemented by quant_method. + layer.quant_method.apply(layer, query, key, value, kv_cache, + self.scale, self.seq_lens_tensor_cpu, + block_tables, isPrefill, attn_metadata, + output) + else: + if kv_cache.numel() > 0: + key_cache, value_cache = kv_cache[0], kv_cache[1] + num_blocks, block_size, _ = key_cache.shape + key_cache = key_cache.view(num_blocks, block_size, + self.num_kv_heads, self.head_size) + value_cache = value_cache.view(num_blocks, block_size, + self.num_kv_heads, + self.head_size) + slots = attn_metadata.slot_mapping + torch_npu.npu_reshapecache(key=key, + value=value, + keyCache=key_cache, + valueCache=value_cache, + slotMapping=slots, + compressType=0, + kvCacheCfg=0) + + if attn_metadata.num_prefills > 0: + + if (attn_metadata.block_tables is None + or attn_metadata.block_tables.numel() == 0): + assert attn_metadata.attn_mask is not None + mask = attn_metadata.attn_mask + assert attn_metadata.prefill_metadata is not None + self.seq_lens_tensor_cpu = torch.from_numpy( + np.array( + attn_metadata.prefill_metadata.seq_lens).astype( + np.int32)) + torch_npu.npu_selfattention( + query=query, + key=key, + value=value, + mask=mask, + maskType=1, + isTriuMask=0, + seqLen=self.seq_lens_tensor_cpu, + scale=self.scale, + qScale=1, + headNum=self.num_heads, + kvHeadNum=self.num_kv_heads, + mlaVHeadSize=0, + calcType=3, + kernelType=0, + clampType=0, + scaleType=0, + quantType=0, + cacheType=0, + batchRunStatusEnable=False, + kvcacheCfg=0, + clampMin=0, + clampMax=0, + inputLayout=0, + windowSize=0, + outDataType=0, + out=output) + else: + # TODO: Will support prefix cache and chunked prefill soon. + raise RuntimeError( + "Prefix cache and chunked prefill are currently not supported." 
+ ) + elif attn_metadata.decode_metadata: + assert kv_cache is not None + self.seq_lens_tensor_cpu = torch.from_numpy( + np.array(attn_metadata.decode_metadata.seq_lens).astype( + np.int32)) + block_tables = attn_metadata.decode_metadata.block_tables + torch_npu.npu_pagedattention( + query=query, + keyCache=key_cache, + valueCache=value_cache, + contextLens=self.seq_lens_tensor_cpu, + maskType=0, + kvHeadNum=self.num_kv_heads, + headNum=self.num_heads, + mlaVHeadSize=0, + qkScale=self.scale, + scaleType=0, + blockTables=block_tables, + batchRunStatusEnable=False, + hasQuantOffset=False, + calcType=3, + quantType=0, + compressType=0, + inputLayout=0, + outDataType=0, + attnOut=output) + + return output.view(num_tokens, self.hidden_size) + + +class AscendMLAAttentionBackendImpl(MLAAttentionImpl): - if kv_cache is not None and len(kv_cache) >= 2: - slot_indices = attn_metadata.slot_mapping - key_cache, value_cache = kv_cache[0], kv_cache[1] - AscendPagedAttention.write_to_paged_cache( - key, - value, - key_cache, - value_cache, - slot_indices, - ) + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + **extra_impl_args, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.hidden_size = self.num_heads * self.head_size + self.kv_cache_dtype = kv_cache_dtype + self.sliding_window = sliding_window + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, + dtype=torch.float32, + device="npu") + self.alibi_slopes = alibi_slopes + self.attn_type = attn_type + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.seq_len_cpu_tensor = None + + # MLA Args + self.q_lora_rank = extra_impl_args['q_lora_rank'] + self.kv_lora_rank = extra_impl_args['kv_lora_rank'] + self.qk_nope_head_dim = extra_impl_args['qk_nope_head_dim'] + self.qk_rope_head_dim = extra_impl_args['qk_rope_head_dim'] + self.qk_head_dim = extra_impl_args['qk_head_dim'] + self.v_head_dim = extra_impl_args['v_head_dim'] + self.rotary_emb = extra_impl_args['rotary_emb'] + self.q_proj = extra_impl_args['q_proj'] + self.kv_b_proj = extra_impl_args['kv_b_proj'] + self.o_proj = extra_impl_args['o_proj'] + self.w_kc = None + self.w_vc = None + def forward( + self, + layer: AttentionLayer, + hidden_states_or_q_c: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AscendMetadata, + attn_type: str = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with Ascend attention. + Args: + hidden_states_or_q_c: shape = [num_tokens, num_heads * head_size] + num_tokens = batch_size * seq_len + kv_c_normed: shape = [num_tokens, num_kv_heads * head_size] + k_pe: shape = [num_tokens, num_kv_heads * head_size] + kv_cache: shape = [1, num_blocks, block_size, + num_kv_heads * head_size] + attn_metadata: Metadata for attention. 
+ Returns: + shape = [batch_size, seq_len * num_heads * head_size] + """ + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 + attn_type = self.attn_type + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "PallasAttentionBackendImpl") + + num_tokens = hidden_states_or_q_c.shape[0] + q = self.q_proj(hidden_states_or_q_c)[0].view(-1, self.num_heads, + self.qk_head_dim) + q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], + dim=-1) if attn_metadata.num_prefills > 0: - if attn_metadata.attn_mask is None: - if num_tokens > 16384: - attn_metadata.sparse_mode = 2 - attention_mask = gen_input_mask( - attn_metadata.max_prefill_seq_len, self.sliding_window, - num_tokens, query.device) - attn_metadata.attn_mask = attention_mask - - if (self.alibi_slopes is not None - and attn_metadata.pse_shift is None): - attn_metadata.pse_shift = _make_alibi_bias( - self.alibi_slopes, - self.num_kv_heads, - dtype=query.dtype, - seq_len=attn_metadata.max_prefill_seq_len, - batch_size=num_tokens, - device=query.device, - ) + assert attn_metadata.prefill_metadata is not None + assert attn_metadata.prefill_metadata.seq_lens is not None + np_positions = np.concatenate([ + np.arange(i) for i in attn_metadata.prefill_metadata.seq_lens + ]) + positions = torch.tensor(np_positions, + device=hidden_states_or_q_c.device) + else: + assert attn_metadata.decode_metadata is not None + np_positions = np.array(attn_metadata.decode_metadata.seq_lens) - 1 + positions = torch.tensor(np_positions, + device=hidden_states_or_q_c.device) + k_pe = k_pe.view(num_tokens, self.num_kv_heads, -1) + + if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding': + ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape + q_pe = q_pe.reshape(num_tokens, -1) + k_pe = k_pe.reshape(num_tokens, -1) + q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) + q_pe = q_pe.view(ori_q_pe_shape) + k_pe = k_pe.view(ori_k_pe_shape) + else: + q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) + + if self.w_kc is None or self.w_vc is None: + kv_b_proj_weight = self.kv_b_proj.weight.reshape( + self.num_heads, self.qk_nope_head_dim + self.v_head_dim, + self.kv_lora_rank) + self.w_kc = kv_b_proj_weight[:, :self. 
+ qk_nope_head_dim, :].contiguous() + self.w_vc = kv_b_proj_weight[:, + self.qk_nope_head_dim:, :].transpose( + 1, 2).contiguous() - if (len(kv_cache) == 0 or attn_metadata.block_tables is None + if attn_metadata.num_prefills > 0: + kv_heads_num = self.num_heads + kv = self.kv_b_proj(kv_c_normed)[0].view(num_tokens, kv_heads_num, + -1) + k_nope, value = kv.split([self.qk_nope_head_dim, self.v_head_dim], + dim=-1) + k_cache = torch.cat( + [kv_c_normed.view(num_tokens, self.num_kv_heads, -1), k_pe], + dim=2) + k_pe = k_pe.repeat(1, self.num_heads, 1) + key = torch.cat([k_nope.view(num_tokens, kv_heads_num, -1), k_pe], + dim=2) + else: + kv_heads_num = self.num_kv_heads + q_nope_t = torch_npu.npu_transpose(q_nope, (1, 0, 2), + require_contiguous=True) + q_nope_out = torch.bmm(q_nope_t, self.w_kc) + q_nope = torch_npu.npu_transpose(q_nope_out, (1, 0, 2), + require_contiguous=True) + k_cache = torch.cat( + [kv_c_normed.view(num_tokens, self.num_kv_heads, -1), k_pe], + dim=2) + + query = torch.cat([q_nope, q_pe], dim=-1).view(num_tokens, + self.num_heads, -1) + + if kv_cache.numel() > 0: + key_cache = kv_cache[0] + num_blocks, block_size, _ = key_cache.shape + + key_cache = key_cache.view( + num_blocks, block_size, self.num_kv_heads, + self.qk_rope_head_dim + self.kv_lora_rank) + slots = attn_metadata.slot_mapping + torch_npu.npu_reshapecache(key=k_cache, + value=None, + keyCache=key_cache, + valueCache=None, + slotMapping=slots, + compressType=0, + kvCacheCfg=1) + + if attn_metadata.num_prefills > 0: + attn_output = torch.empty(num_tokens, + self.num_heads, + self.v_head_dim, + dtype=query.dtype, + device="npu") + if (attn_metadata.block_tables is None or attn_metadata.block_tables.numel() == 0): - max_seq_len = attn_metadata.max_prefill_seq_len - - # shape of q/k/v [B,S*H] --> [B,S,N,D] - query = query.view(-1, max_seq_len, self.num_heads, - self.head_size).transpose(1, 2) - key = key.view(-1, max_seq_len, self.num_kv_heads, - self.head_size).transpose(1, 2) - value = value.view(-1, max_seq_len, self.num_kv_heads, - self.head_size).transpose(1, 2) - # FA for prefill phase - output = torch_npu.npu_prompt_flash_attention( - query, - key, - value, - pse_shift=attn_metadata.pse_shift, - atten_mask=attn_metadata.attn_mask, - num_heads=self.num_heads, - scale_value=1 / math.sqrt(self.head_size), - input_layout="BNSD", - num_key_value_heads=self.num_kv_heads, - pre_tokens=65535, - next_tokens=0, - sparse_mode=attn_metadata.sparse_mode, - ) - # reshape to [B,H] - output = output.transpose(1, 2).reshape( - num_tokens, self.num_heads * self.head_size) + assert attn_metadata.attn_mask is not None + mask = attn_metadata.attn_mask + assert attn_metadata.prefill_metadata is not None + assert attn_metadata.prefill_metadata.seq_lens is not None + self.seq_lens_tensor_cpu = torch.from_numpy( + np.array(attn_metadata.prefill_metadata.seq_lens).astype( + np.int32)) + torch_npu.npu_selfattention(query=query, + key=key, + value=value, + kvcacheCfg=0, + mask=mask, + maskType=1, + isTriuMask=0, + seqLen=self.seq_lens_tensor_cpu, + scale=self.scale, + qScale=1, + scaleType=0, + headNum=self.num_heads, + kvHeadNum=self.num_heads, + mlaVHeadSize=0, + calcType=3, + kernelType=0, + clampType=0, + quantType=0, + cacheType=0, + windowSize=0, + clampMin=0, + clampMax=0, + batchRunStatusEnable=False, + inputLayout=0, + outDataType=0, + out=attn_output) else: - # prefix-enabled attention - assert attn_type == AttentionType.DECODER, ( - "Only decoder-only models support prefix caching") - assert attn_metadata.seq_lens is not 
None - assert kv_cache is not None - query = query.view(query.shape[0], -1, - self.num_heads * self.head_size) - output = torch.zeros(query.shape, - device=query.device, - dtype=query.dtype) - # TODO (Mengqing Cao): torch_npu.npu_incre_flash_attention - # support only when `S == 1`, OPTIMIZE ME when prefix caching - # is supported in torch-npu ops. - for i in range(query.shape[0]): - # FA for prefill phase - output[i] = torch_npu.npu_incre_flash_attention( - query[i].unsqueeze(0), - key_cache, - value_cache, - num_heads=self.num_heads, - num_key_value_heads=self.num_kv_heads, - scale_value=self.scale, - input_layout="BSH", - block_table=attn_metadata.block_tables, - block_size=key_cache. - shape[1], # max val of block_size == 512 - actual_seq_lengths=attn_metadata.seq_lens, - ) - # [B,S,H] --> [B,H] - output = output.squeeze(1) - + # TODO: Will support prefix cache and chunked prefill soon. + raise RuntimeError( + "Prefix cache and chunked prefill are currently not supported." + ) elif attn_metadata.decode_metadata: - # FA for decoding phase assert kv_cache is not None - # shape of query [B,S*H] --> [B,S,H] - query = query.view( - -1, - 1, - self.head_size * self.num_heads, - ) - output = torch_npu.npu_incre_flash_attention( - query, - key_cache, - value_cache, - num_heads=self.num_heads, - num_key_value_heads=self.num_kv_heads, - scale_value=self.scale, - input_layout="BSH", - block_table=attn_metadata.block_tables, - block_size=key_cache.shape[1], # max val of block_size == 512 - actual_seq_lengths=attn_metadata.seq_lens, - ) + attn_output = torch.empty(num_tokens, + self.num_heads, + self.kv_lora_rank, + dtype=query.dtype, + device="npu") + self.seq_lens_tensor_cpu = torch.from_numpy( + np.array(attn_metadata.decode_metadata.seq_lens).astype( + np.int32)) + block_tables = attn_metadata.decode_metadata.block_tables + torch_npu.npu_pagedattention(query=query, + keyCache=key_cache, + valueCache=None, + contextLens=self.seq_lens_tensor_cpu, + maskType=0, + kvHeadNum=self.num_kv_heads, + headNum=self.num_heads, + mlaVHeadSize=self.kv_lora_rank, + qkScale=self.scale, + blockTables=block_tables, + batchRunStatusEnable=False, + hasQuantOffset=False, + compressType=0, + calcType=0, + scaleType=0, + quantType=0, + inputLayout=0, + outDataType=-1, + attnOut=attn_output) + attn_output_t = torch_npu.npu_transpose(attn_output, (1, 0, 2), + require_contiguous=True) + attn_output_t = torch.bmm(attn_output_t, self.w_vc) + attn_output = torch_npu.npu_transpose(attn_output_t, (1, 0, 2), + require_contiguous=True) + + output, _ = self.o_proj(attn_output.view(num_tokens, -1)) - # [B,S,H] --> [B,H] - output = output.squeeze(1) return output - - -def gen_input_mask(seq_len, sliding_window, len, device): - """ - Generating lower triangular matrix - """ - if len > 16384: - # improve computing performance on NPU when input tokens are huge - global SHARE_MASK_TRIL_PREFIX_CACHE - if SHARE_MASK_TRIL_PREFIX_CACHE is None: - SHARE_MASK_TRIL_PREFIX_CACHE = torch.triu( - torch.ones(1, 1, 2048, 2048, dtype=bool, device=device), - diagonal=1, - ) - attention_mask = SHARE_MASK_TRIL_PREFIX_CACHE - else: - global SHARE_MASK_TRIL - if SHARE_MASK_TRIL is None or SHARE_MASK_TRIL.shape[0] < seq_len: - SHARE_MASK_TRIL = ~torch.tril( - torch.ones(seq_len, seq_len, dtype=bool, device=device)) - - attention_mask = SHARE_MASK_TRIL - if sliding_window is not None: - attention_mask = ~attention_mask - attention_mask = torch.triu(attention_mask, - diagonal=1 - sliding_window) - attention_mask = ~attention_mask - - return attention_mask 
- - -def _make_alibi_bias( - alibi_slopes: torch.Tensor, - num_kv_heads: int, - dtype: torch.dtype, - seq_len: int, - batch_size: int, - device: torch.device, -): - alibi_slopes = alibi_slopes.to(device) - bias = torch.arange(seq_len, dtype=dtype, device=device) - # NOTE(zhuohan): HF uses - # `bias = bias[None, :].repeat(seq_len, 1)` - # here. We find that both biases give the same results, but - # the bias below more accurately follows the original ALiBi - # paper. - # Calculate a matrix where each element represents ith element- jth - # element. - bias = bias[None, :] - bias[:, None] - - padded_len = (seq_len + 7) // 8 * 8 - num_heads = alibi_slopes.shape[0] - bias = torch.empty( - 1, - num_heads, - seq_len, - padded_len, - device=device, - dtype=dtype, - )[:, :, :, :seq_len].copy_(bias) - bias.mul_(alibi_slopes[:, None, None]) - if num_heads != num_kv_heads: - bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) - - return bias diff --git a/vllm_ascend/model_runner.py b/vllm_ascend/model_runner.py index 77e093b5..b43d2d19 100644 --- a/vllm_ascend/model_runner.py +++ b/vllm_ascend/model_runner.py @@ -53,7 +53,7 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, flatten_2d_lists, - is_pin_memory_available, make_tensor_with_pad) + is_pin_memory_available) from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, @@ -511,50 +511,21 @@ def build(self) -> ModelInputForNPU: for data in self.inter_data_list } - batch_size = len(input_tokens) - - if self.inter_data_list[0].is_prompt: - input_tokens_tensor = make_tensor_with_pad( - input_tokens, 0, dtype=torch.int, device=self.runner.device) - input_tokens_tensor = torch.flatten(input_tokens_tensor) - if mrope_input_positions is not None: - mrope_input_positions_tensor = make_tensor_with_pad( - mrope_input_positions, - 0, - dtype=torch.int, - device=self.runner.device) - input_positions_tensor = torch.tensor( - mrope_input_positions_tensor, - dtype=torch.long, - device=self.runner.device) - else: - input_positions_tensor = make_tensor_with_pad( - input_positions, - 0, - dtype=torch.int, - device=self.runner.device) - input_positions_tensor = torch.flatten(input_positions_tensor) - - max_seq_len = max(seq_lens) - seq_lens = len(seq_lens) * [max_seq_len] + input_tokens_tensor = torch.tensor(flatten_2d_lists(input_tokens), + dtype=torch.long, + device=self.runner.device) + if mrope_input_positions is not None: + input_positions_tensor = torch.tensor(mrope_input_positions, + dtype=torch.long, + device=self.runner.device) else: - input_tokens_tensor = torch.tensor(flatten_2d_lists(input_tokens), - dtype=torch.long, - device=self.runner.device) - if mrope_input_positions is not None: - input_positions_tensor = torch.tensor( - mrope_input_positions, - dtype=torch.long, - device=self.runner.device) - else: - input_positions_tensor = torch.tensor( - flatten_2d_lists(input_positions), - dtype=torch.long, - device=self.runner.device) + input_positions_tensor = torch.tensor( + flatten_2d_lists(input_positions), + dtype=torch.long, + device=self.runner.device) # Attention metadata. - attn_metadata = self.attn_metadata_builder.build( - seq_lens, query_lens, -1, batch_size) + attn_metadata = self.attn_metadata_builder.build(seq_lens, query_lens) # Multi-modal data. 
multi_modal_kwargs_list = [ @@ -749,10 +720,14 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, mrope_input_positions, mrope_position_delta = \ MRotaryEmbedding.get_input_positions( token_ids, - hf_config, image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, - second_per_grid_ts=None, + image_token_id=hf_config.image_token_id, + video_token_id=hf_config.video_token_id, + vision_start_token_id=hf_config.vision_start_token_id, + vision_end_token_id=hf_config.vision_end_token_id, + spatial_merge_size=hf_config.vision_config. + spatial_merge_size, context_len=inter_data.context_lens[seq_idx], seq_len=inter_data.seq_lens[seq_idx], ) diff --git a/vllm_ascend/ops/__init__.py b/vllm_ascend/ops/__init__.py index bdc40cd5..2ae5f1a7 100644 --- a/vllm_ascend/ops/__init__.py +++ b/vllm_ascend/ops/__init__.py @@ -14,5 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +import vllm_ascend.ops.activation # noqa +import vllm_ascend.ops.fused_moe # noqa import vllm_ascend.ops.layernorm # noqa +import vllm_ascend.ops.rotary_embedding # noqa diff --git a/vllm_ascend/ops/activation.py b/vllm_ascend/ops/activation.py new file mode 100644 index 00000000..03072456 --- /dev/null +++ b/vllm_ascend/ops/activation.py @@ -0,0 +1,29 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import torch +from vllm.model_executor.layers.activation import SiluAndMul + + +def silu_and_mul_forward_oot(self, x: torch.Tensor) -> torch.Tensor: + import torch_npu + + out = torch_npu.npu_swiglu(x) + return out + + +SiluAndMul.forward_oot = silu_and_mul_forward_oot diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py new file mode 100644 index 00000000..cbb86224 --- /dev/null +++ b/vllm_ascend/ops/fused_moe.py @@ -0,0 +1,176 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from typing import Callable, Optional + +import torch +import torch_npu +from vllm.model_executor.layers.fused_moe.layer import \ + UnquantizedFusedMoEMethod + + +def group_topk(hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: Optional[int] = 0, + topk_group: Optional[int] = 0, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None): + + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. We use biased + # scores for expert selection but original scores for routing weights + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + + torch_npu.npu_group_topk(input=scores, + out=scores, + group_num=num_expert_group, + k=topk_group) + if e_score_correction_bias is not None: + topk_ids = torch.topk(scores, k=topk, dim=-1, sorted=False)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(scores, + k=topk, + dim=-1, + sorted=False) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) + + +def fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor, + w2: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, top_k: int): + # Check constraints. + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + ori_shape = hidden_states.shape + if len(ori_shape) == 3: + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + + num_tokens, _ = hidden_states.shape + E, N, _ = w1.shape + + row_idx_len = num_tokens * top_k + row_idx = torch.arange(0, + row_idx_len, + dtype=torch.int32, + device=topk_weights.device).view(top_k, -1).permute( + 1, 0).contiguous() + expanded_x, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( + hidden_states, + row_idx=row_idx, + expert_idx=topk_ids, + active_num=num_tokens) + + expert_tokens = torch_npu.npu_moe_compute_expert_tokens( + expanded_expert_idx, E) + expert_tokens = expert_tokens.to(torch.int64) + + w1 = w1.transpose(1, 2) + gate_up_out_list = torch_npu.npu_grouped_matmul(x=[expanded_x], + weight=[w1], + split_item=2, + group_list_type=0, + group_type=0, + group_list=expert_tokens) + + # TODO: Remove this in the future. 
+ gate_up_out = torch.cat(gate_up_out_list, dim=0) + gate_up_out = torch_npu.npu_swiglu(gate_up_out) + + w2 = w2.transpose(1, 2) + down_out_list = torch_npu.npu_grouped_matmul(x=[gate_up_out], + weight=[w2], + split_item=2, + group_list_type=0, + group_type=0, + group_list=expert_tokens) + + down_out_list = torch.cat(down_out_list, dim=0) + # TODO: Reorder device memory 2 times here, replace the current + # implementation here when suitable operators become available. + routing_weights = topk_weights.to(down_out_list.dtype) + hidden_states = torch_npu.npu_moe_finalize_routing( + down_out_list, + skip1=None, + skip2=None, + bias=None, + scales=routing_weights, + expanded_src_to_dst_row=expanded_row_idx, + export_for_source_row=topk_ids) + if len(ori_shape) == 3: + hidden_states = hidden_states.view(ori_shape) + return hidden_states + + +def forward_oot( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None +) -> torch.Tensor: + + topk_weights, topk_ids = group_topk( + hidden_states=x, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + return fused_experts(hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k) + + +UnquantizedFusedMoEMethod.forward_oot = forward_oot diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py new file mode 100644 index 00000000..2279ad15 --- /dev/null +++ b/vllm_ascend/ops/rotary_embedding.py @@ -0,0 +1,56 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Optional, Tuple + +import torch +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding + + +def rope_forward_oot( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + import torch_npu + + if self.cos_sin_cache.device != query.device: + self.cos_sin_cache = self.cos_sin_cache.to(query.device) + if self.cos_sin_cache.dtype != query.dtype: + self.cos_sin_cache = self.cos_sin_cache.to(query.dtype) + if offsets is not None: + raise NotImplementedError( + "Batched rotary embedding is currently not supported on NPU.") + else: + # TODO: Remove the contiguous in the future. 
+ query = query.contiguous() + key = key.contiguous() + torch_npu.npu_rope( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + self.is_neox_style, + ) + + return query, key + + +RotaryEmbedding.forward_oot = rope_forward_oot diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 1f22e564..9ad05c26 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -16,7 +16,7 @@ # import os -from typing import Optional, Tuple +from typing import TYPE_CHECKING, Optional, Tuple import torch @@ -28,6 +28,11 @@ from vllm.config import VllmConfig from vllm.platforms import Platform, PlatformEnum +if TYPE_CHECKING: + from vllm.utils import FlexibleArgumentParser +else: + FlexibleArgumentParser = None + os.environ["RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES"] = "1" @@ -53,6 +58,15 @@ class NPUPlatform(Platform): ray_device_key: str = "NPU" device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES" + supported_quantization: list[str] = ["ascend"] + + @classmethod + def pre_register_and_update(cls, + parser: Optional[FlexibleArgumentParser] = None + ) -> None: + from vllm_ascend.quantization.quant_config import \ + AscendQuantConfig # noqa: F401 + @classmethod def get_device_capability(cls, device_id: int = 0): return None @@ -96,11 +110,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: parallel_config.worker_cls = "vllm_ascend.worker.NPUWorker" cache_config = vllm_config.cache_config if cache_config and cache_config.block_size is None: - cache_config.block_size = 128 + # TODO: Set block_size to 128 will lead unexpected accuracy issue in mla case. Please set block_size to 128 back once the problem is fixed. + cache_config.block_size = 16 @classmethod def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla): + if use_mla: + return "vllm_ascend.attention.AscendMLAAttentionBackend" return "vllm_ascend.attention.AscendAttentionBackend" @classmethod diff --git a/vllm_ascend/quantization/__init__.py b/vllm_ascend/quantization/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py new file mode 100644 index 00000000..35d767be --- /dev/null +++ b/vllm_ascend/quantization/quant_config.py @@ -0,0 +1,256 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from types import MappingProxyType +from typing import Any, Dict, List, Mapping, Optional + +import torch +import torch_npu # noqa: F401 +from vllm.distributed import get_tensor_model_parallel_rank +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + RowParallelLinear, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization import \ + register_quantization_config +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + ModelWeightParameter) + +from .quantizer import AscendQuantizer + +logger = init_logger(__name__) + + +@register_quantization_config("ascend") +class AscendQuantConfig(QuantizationConfig): + """Config class for Ascend""" + + def __init__(self, quant_config: Dict[str, Any]): + self.quant_description = quant_config + + def __repr__(self) -> str: + return "AscendQuantConfig:\n" + super().__repr__() + + @classmethod + def get_name(cls) -> str: + return "ascend" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.int8, torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + raise NotImplementedError( + "Ascend hardware dose not support \"get_min_capability\" feature.") + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "AscendQuantConfig": + return cls(config) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + if torch.npu.is_available(): + return "ascend" + return None + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention + if isinstance(layer, LinearBase): + if self.is_layer_skipped_ascend(prefix, + self.packed_modules_mapping): + return UnquantizedLinearMethod() + return AscendLinearMethod(self) + if isinstance(layer, Attention) and \ + 'fa_quant_type' in self.quant_description.keys(): + return AscendQKVQuantAttentionMethod(self) + return None + + def is_layer_skipped_ascend( + self, + prefix: str, + fused_mapping: Mapping[str, List[str]] = MappingProxyType({})): + # adapted from vllm.model_executor.layers.quantization.utils.quant_utils.is_layer_skipped + proj_name = prefix.split(".")[-1] + if proj_name in fused_mapping: + shard_prefixes = [ + prefix.replace(proj_name, shard_proj_name) + for shard_proj_name in fused_mapping[proj_name] + ] + + is_skipped = None + for shard_prefix in shard_prefixes: + is_shard_skipped = self.quant_description[shard_prefix + + '.weight'] == "FLOAT" + + if is_skipped is None: + is_skipped = is_shard_skipped + elif is_shard_skipped != is_skipped: + raise ValueError( + f"Detected some but not all shards of {prefix} " + "are quantized. All shards of fused layers " + "to have the same precision.") + else: + is_skipped = self.quant_description[prefix + '.weight'] == "FLOAT" + + assert is_skipped is not None + return is_skipped + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class AscendLinearMethod(LinearMethodBase): + """Linear method for Ascend quantization. + + Args: + quant_config: The Ascend quantization config. 
+ """ + + def __init__(self, quant_config: AscendQuantConfig) -> None: + self.quantizer = AscendQuantizer.get_quantizer( + quant_config.quant_description) + self.quant_method = self.quantizer.build_linear_method() + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + del output_size + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + weights = self.quant_method.create_weights(input_size_per_partition, + output_size_per_partition, + params_dtype) + + weight_name = self.quant_method.get_weight() + if weight_name in weights.keys(): + layer.register_parameter( + weight_name, + ModelWeightParameter(data=weights[weight_name].transpose(0, 1), + input_dim=1, + output_dim=0, + weight_loader=weight_loader)) + else: + raise ValueError( + f"{weight_name} is nor registered. Please check your linear quant method implementation." + ) + + pertensor_names = self.quant_method.get_pertensor_param() + for pertensor_name in pertensor_names: + if pertensor_name in weights.keys(): + param = BasevLLMParameter(data=weights[pertensor_name], + weight_loader=weight_loader) + # disable warning + param.ignore_warning = True + layer.register_parameter(pertensor_name, param) + else: + raise ValueError( + f"{pertensor_name} is nor registered. Please check your linear quant method implementation." + ) + + perchannel_names = self.quant_method.get_perchannel_param() + for perchannel_name in perchannel_names: + if perchannel_name in weights.keys(): + layer.register_parameter( + perchannel_name, + ChannelQuantScaleParameter(data=weights[perchannel_name], + output_dim=0, + weight_loader=weight_loader)) + else: + raise ValueError( + f"{perchannel_name} is nor registered. Please check your linear quant method implementation." + ) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if hasattr(self.quant_method, + 'transpose_weight') and self.quant_method.transpose_weight: + layer.weight.data = layer.weight.data.transpose(1, 0) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if isinstance(layer, RowParallelLinear): + tp_rank = get_tensor_model_parallel_rank() + return self.quant_method.apply(layer, x, bias, tp_rank) + return self.quant_method.apply(layer, x, bias) + + +class AscendQKVQuantAttentionMethod(BaseKVCacheMethod): + """Linear method for Ascend quantization. + + Args: + quant_config: The Ascend quantization config. + """ + + def __init__(self, quant_config: AscendQuantConfig) -> None: + self.quantizer = AscendQuantizer.get_quantizer( + quant_config.quant_description) + self.quant_method = self.quantizer.build_attention_method() + + def create_weights(self, layer: torch.nn.Module) -> None: + # ascend attention quantization might include some extra weights + # and must be loaded by dummy modules + extra_module_names = self.quant_method.get_extra_module_names() + for name in extra_module_names: + setattr(layer, name, torch.nn.Module()) + + # During model initialization, the default dtype is set as the model + # weight and activation dtype. 
+ dtype = torch.get_default_dtype() + weights = self.quant_method.create_weights(dtype, layer.num_heads, + layer.num_kv_heads) + + for name, weight in weights.items(): + module_name, weight_name = name.split('.') + module = getattr(layer, module_name) + module.register_parameter( + weight_name, torch.nn.Parameter(weight, requires_grad=False)) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if hasattr(self.quant_method, "process_weights_after_loading"): + self.quant_method.process_weights_after_loading(layer) + + def apply(self, layer: torch.nn.Module, query: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + kv_cache: List[torch.Tensor], scale: torch.Tensor, + seq_lens_tensor_cpu: int, block_tables: torch.Tensor, + isPrefill: bool, attn_metadata, output) -> torch.Tensor: + return self.quant_method.apply(layer, query, key, value, kv_cache, + scale, seq_lens_tensor_cpu, + block_tables, isPrefill, attn_metadata, + output) diff --git a/vllm_ascend/quantization/quantizer.py b/vllm_ascend/quantization/quantizer.py new file mode 100644 index 00000000..f6cc450c --- /dev/null +++ b/vllm_ascend/quantization/quantizer.py @@ -0,0 +1,51 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import importlib +from typing import Any, Dict, List + +CUSTOMIZED_QUANTIZER_TYPE: List[str] = [] + + +class AscendQuantizer: + """An interface to different quantization implementations for ascend hardwares.""" + + @classmethod + def get_quantizer(cls, quant_config: Dict[str, Any]): + # TODO: Need a param to choose quantization algorithms. + quantization_algorithm = '' + + if quantization_algorithm in CUSTOMIZED_QUANTIZER_TYPE: + return + + try: + module = importlib.import_module("mindie_turbo") + MindIETurboQuantizer = module.MindIETurboQuantizer + except Exception: + raise NotImplementedError( + "There is no available ascend quantizer.") + + return MindIETurboQuantizer.get_quantizer(quant_config) + + def build_linear_method(self): + raise NotImplementedError + + def build_moe_method(self): + raise NotImplementedError + + def build_attention_method(self): + raise NotImplementedError
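
Usage sketch for the new AttentionMaskBuilder (illustrative only, not part of the diff above). The builder keeps one cached lower-triangular mask and regrows or recasts it on demand: prefill takes a square slice, decode takes one row per sequence selected by the per-sequence length passed in. For fp16 the masked positions are filled with torch.finfo(torch.float32).min; for other dtypes the fill value is 1, presumably consumed as a boolean-style mask by the NPU kernels. The snippet assumes the patch is applied and that vllm_ascend.attention imports in your environment; the tensors themselves stay on CPU.

import torch
from vllm_ascend.attention import AttentionMaskBuilder

builder = AttentionMaskBuilder.initialize_from_len(128, torch.float16)

# Prefill: a square causal mask, sliced out of the cached one.
prefill_mask = builder.get_attn_mask(32, torch.float16, torch.device("cpu"))
assert prefill_mask.shape == (32, 32)

# Asking for a longer mask (or a different dtype) transparently regrows the cache.
bigger = builder.get_attn_mask(256, torch.bfloat16, torch.device("cpu"))
assert bigger.shape == (256, 256)

# Decode: one mask row per running sequence, indexed by the length handed in.
input_lengths = torch.tensor([3, 7])
decode_mask = builder.get_decode_attn_mask(input_lengths, 8, torch.bfloat16,
                                           torch.device("cpu"))
assert decode_mask.shape == (2, 1, 8)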
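
Shape sketch of the MLA decode-path weight absorption in AscendMLAAttentionBackendImpl. The patch folds kv_b_proj into a query-side factor w_kc and an output-side factor w_vc so that paged attention can run directly on the compressed kv_lora_rank cache; the CPU snippet below only replays the reshape/bmm flow with dummy tensors. The dimensions are illustrative (roughly DeepSeek-V2-like) and are not read from any config.

import torch

# Illustrative sizes only.
num_tokens, num_heads = 4, 16
qk_nope_head_dim, qk_rope_head_dim = 128, 64
kv_lora_rank, v_head_dim = 512, 128

# kv_b_proj.weight reshaped as in the patch: (H, qk_nope + v_head_dim, kv_lora_rank).
kv_b_proj_weight = torch.randn(num_heads, qk_nope_head_dim + v_head_dim,
                               kv_lora_rank)
w_kc = kv_b_proj_weight[:, :qk_nope_head_dim, :]                   # (H, d_nope, r)
w_vc = kv_b_proj_weight[:, qk_nope_head_dim:, :].transpose(1, 2)   # (H, r, d_v)

# Query absorption at decode time: project q_nope into the rank-r latent space.
q_nope = torch.randn(num_tokens, num_heads, qk_nope_head_dim)
q_latent = torch.bmm(q_nope.transpose(0, 1), w_kc).transpose(0, 1)
assert q_latent.shape == (num_tokens, num_heads, kv_lora_rank)

# The query handed to paged attention is [q_latent | q_pe].
q_pe = torch.randn(num_tokens, num_heads, qk_rope_head_dim)
query = torch.cat([q_latent, q_pe], dim=-1)
assert query.shape == (num_tokens, num_heads, kv_lora_rank + qk_rope_head_dim)

# Attention output comes back in the latent space and is lifted by w_vc before
# o_proj, mirroring the bmm pair at the end of the MLA forward.
attn_out_latent = torch.randn(num_tokens, num_heads, kv_lora_rank)
attn_out = torch.bmm(attn_out_latent.transpose(0, 1), w_vc).transpose(0, 1)
assert attn_out.shape == (num_tokens, num_heads, v_head_dim)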
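
Routing sketch for group_topk in ops/fused_moe.py. On the NPU the group filtering is done by torch_npu.npu_group_topk; the CPU version below assumes that kernel keeps only the topk_group strongest expert groups per token (scored by their best in-group expert) and zeroes the rest, which matches DeepSeek-style grouped routing but should be read as an assumption about the kernel, not its documented contract. Only softmax scoring without the correction bias is shown.

import torch

def grouped_topk_cpu(gating_output: torch.Tensor, topk: int,
                     num_expert_group: int, topk_group: int,
                     renormalize: bool = True):
    # CPU sketch of grouped expert routing with softmax scoring.
    num_tokens, num_experts = gating_output.shape
    scores = torch.softmax(gating_output, dim=-1)

    # Score each group by its best expert, keep the top `topk_group` groups per
    # token and zero out the rest (assumed behaviour of npu_group_topk).
    group_scores = scores.view(num_tokens, num_expert_group, -1).max(dim=-1).values
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
    expert_mask = group_mask.unsqueeze(-1).expand(
        num_tokens, num_expert_group,
        num_experts // num_expert_group).reshape(num_tokens, num_experts)
    scores = scores * expert_mask

    # Per-token top-k over the surviving experts, then optional renormalization.
    topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1, sorted=False)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)

# 4 tokens, 8 experts in 4 groups; route each token to 2 experts drawn from
# its 2 strongest groups.
logits = torch.randn(4, 8)
weights, ids = grouped_topk_cpu(logits, topk=2, num_expert_group=4, topk_group=2)
assert weights.shape == ids.shape == (4, 2)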
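
Selecting the new quantization plug-in. The patch registers "ascend" via register_quantization_config and lists it in NPUPlatform.supported_quantization, so picking it at engine construction would look roughly like the sketch below. The model path is a placeholder, and an actual run requires an Ascend NPU plus the mindie_turbo package that AscendQuantizer imports at runtime.

from vllm import LLM, SamplingParams

llm = LLM(
    model="/path/to/ascend-quantized-model",  # placeholder path
    quantization="ascend",                    # name registered by this patch
    max_model_len=4096,
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)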