[BugFix] Add int8 cache dtype when using attention quantization #128
Open
Angazenn wants to merge 6 commits into vllm-project:main from Angazenn:bug_fix
@@ -0,0 +1,18 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from vllm_ascend.patch import patch_attention, patch_cache_dtype  # noqa
@@ -0,0 +1,160 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/vllm/attention/layer.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file is used to monkey patch the vLLM Attention.__init__ function
# and move the instantiation of num_heads, head_size and num_kv_heads
# ahead of the initialization of the attention quant method, which the
# Ascend attention quant method requires in order to initialize.
# Remove this file once vllm supports this ordering natively.

from typing import Any, Dict, List, Optional

import torch

import vllm.envs as envs
from vllm.attention import Attention, AttentionType
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import current_platform


def attention_init(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: Optional[int] = None,
    alibi_slopes: Optional[List[float]] = None,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    blocksparse_params: Optional[Dict[str, Any]] = None,
    logits_soft_cap: Optional[float] = None,
    per_layer_sliding_window: Optional[int] = None,
    use_mla: bool = False,
    prefix: str = "",
    attn_type: str = AttentionType.DECODER,
    **extra_impl_args,
) -> None:
    super(Attention, self).__init__()
    if per_layer_sliding_window is not None:
        # per-layer sliding window
        sliding_window = per_layer_sliding_window
    elif cache_config is not None:
        # model-level sliding window
        sliding_window = cache_config.sliding_window
    else:
        sliding_window = None

    if cache_config is not None:
        kv_cache_dtype = cache_config.cache_dtype
        block_size = cache_config.block_size
        is_attention_free = cache_config.is_attention_free
        calculate_kv_scales = cache_config.calculate_kv_scales
    else:
        kv_cache_dtype = "auto"
        block_size = 16
        is_attention_free = False
        calculate_kv_scales = False
    if num_kv_heads is None:
        num_kv_heads = num_heads

    # The default k/v_scale is set to 1.0. This is ignored
    # when kv-cache is not fp8, and should be used with
    # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
    # expect the pre-quantized k/v_scale to be loaded along
    # with the model weights.
    self.kv_cache_dtype = kv_cache_dtype
    self.calculate_kv_scales = calculate_kv_scales
    self._k_scale = torch.tensor(1.0, dtype=torch.float32)
    self._v_scale = torch.tensor(1.0, dtype=torch.float32)

    # We also keep the float32 versions of k/v_scale for attention
    # backends that don't support tensors (Flashinfer)
    self._k_scale_float = 1.0
    self._v_scale_float = 1.0

    # These three attributes are set here, ahead of the quant method
    # instantiation below, so that the quant method's create_weights()
    # can read them.
    self.num_heads = num_heads
    self.head_size = head_size
    self.num_kv_heads = num_kv_heads

    quant_method = quant_config.get_quant_method(
        self, prefix=prefix) if quant_config else None
    if quant_method is not None:
        assert isinstance(quant_method, BaseKVCacheMethod)
        # TODO (mgoin): kv cache dtype should be specified in the FP8
        # checkpoint config and become the "auto" behavior
        if self.kv_cache_dtype == "fp8_e5m2":
            raise ValueError("fp8_e5m2 kv-cache is not supported with "
                             "fp8 checkpoints.")
        # If quantization is enabled, we make "k_scale" and "v_scale"
        # parameters so that it can be loaded from the model checkpoint.
        # The k/v_scale will then be converted back to native float32
        # values after weight loading.
        self.quant_method = quant_method
        self.quant_method.create_weights(self)

    # During model initialization, the default dtype is set as the model
    # weight and activation dtype.
    dtype = torch.get_default_dtype()
    attn_backend = get_attn_backend(head_size,
                                    dtype,
                                    kv_cache_dtype,
                                    block_size,
                                    is_attention_free,
                                    blocksparse_params is not None,
                                    use_mla=use_mla)
    impl_cls = attn_backend.get_impl_cls()
    self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         **extra_impl_args)
    self.sliding_window = sliding_window
    self.backend = backend_name_to_enum(attn_backend.get_name())
    self.dtype = dtype

    # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
    # torch.compile works by registering the attention as one giant
    # opaque custom op. For other platforms, we directly call them
    # and let torch.compile handle them.
    self.use_direct_call = not current_platform.is_cuda_alike(
    ) and not current_platform.is_cpu()

    self.use_output = attn_backend.accept_output_buffer
    compilation_config = get_current_vllm_config().compilation_config
    if prefix in compilation_config.static_forward_context:
        raise ValueError(f"Duplicate layer name: {prefix}")
    compilation_config.static_forward_context[prefix] = self
    self.layer_name = prefix
    self.attn_type = attn_type
    # use a placeholder kv cache tensor during init, which will be replaced
    # by bind_kv_cache
    # this variable will not be accessed if use_direct_call is True
    self.kv_cache = [
        torch.tensor([]) for _ in range(
            get_current_vllm_config().parallel_config.pipeline_parallel_size)
    ]

    self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
    self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)


Attention.__init__ = attention_init
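The reordering above is what lets an Ascend KV-cache quant method read the layer geometry while creating its weights. Below is a minimal, hypothetical sketch of such a consumer; it is not the PR's actual Ascend implementation, and the class name and per-head scale buffer are illustrative only.

import torch

from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod


class IllustrativeInt8KVCacheMethod(BaseKVCacheMethod):
    """Illustration only: shows what the patched __init__ ordering enables."""

    def create_weights(self, layer: torch.nn.Module) -> None:
        # These attributes are available here only because the patched
        # attention_init sets num_heads/head_size/num_kv_heads *before*
        # calling quant_method.create_weights(self).
        num_kv_heads = layer.num_kv_heads
        head_size = layer.head_size
        # Purely illustrative per-head scale buffer for an int8 kv-cache.
        layer.kv_int8_scale = torch.nn.Parameter(
            torch.ones(num_kv_heads, head_size, dtype=torch.float32),
            requires_grad=False)
        # Still create the default k_scale/v_scale parameters.
        super().create_weights(layer)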
@@ -0,0 +1,23 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file is used to monkey patch the int8 cache dtype into vllm to support Ascend.
# Remove this file once vllm supports the int8 cache dtype natively.

import torch
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

STR_DTYPE_TO_TORCH_DTYPE['int8'] = torch.int8
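For context, the effect of this one-line registration can be checked with a short, hedged snippet; the module path is taken from the imports in the package __init__ shown in the first file of this diff.

import torch
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE

# Importing the patch module for its side effect registers the mapping.
import vllm_ascend.patch.patch_cache_dtype  # noqa: F401

# After the import, the "int8" kv-cache dtype string resolves to a real
# torch dtype, so downstream cache allocation can look it up.
assert STR_DTYPE_TO_TORCH_DTYPE["int8"] is torch.int8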
Does vllm originally have an entrypoint to register custom quant types into STR_DTYPE_TO_TORCH_DTYPE? If not, maybe we need to raise an issue to vllm? @wangxiyuan please check this.
No, we can create a PR to vllm later.
Currently, for this kind of change (monkey patch), please move it to the patch module.
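What the reviewer asks for matches the layout already used by the first file in this diff (assumed here to be vllm_ascend/patch/__init__.py): each monkey patch lives in its own module under the patch package, so applying every patch is a single side-effect import. A hedged sketch, with file names inferred from those imports:

# vllm_ascend/patch/__init__.py re-exports the individual patch modules:
#     from vllm_ascend.patch import patch_attention, patch_cache_dtype  # noqa
# so callers apply all patches with one import:
import vllm_ascend.patch  # noqa: F401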