From bafc5f576b04ee5d1094ce6996c304819e9bcbd5 Mon Sep 17 00:00:00 2001 From: angazenn Date: Wed, 5 Feb 2025 21:08:53 +0800 Subject: [PATCH 01/26] add ascend quantize Signed-off-by: angazenn --- vllm_ascend/ops/layernorm.py | 49 +++++--- vllm_ascend/platform.py | 4 + vllm_ascend/quantize/__init__.py | 0 vllm_ascend/quantize/quant_config.py | 171 +++++++++++++++++++++++++++ vllm_ascend/quantize/quantizer.py | 52 ++++++++ 5 files changed, 259 insertions(+), 17 deletions(-) create mode 100644 vllm_ascend/quantize/__init__.py create mode 100644 vllm_ascend/quantize/quant_config.py create mode 100644 vllm_ascend/quantize/quantizer.py diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py index 719aa977..d8beb95e 100644 --- a/vllm_ascend/ops/layernorm.py +++ b/vllm_ascend/ops/layernorm.py @@ -1,31 +1,22 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# This file is a part of the vllm-ascend project. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - from typing import Optional, Tuple, Union import torch from vllm.model_executor.layers.layernorm import RMSNorm +try: + from mindie_turbo import RMSNormWithAntiOutlier +except: + pass + def forward_oot( self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if hasattr(self, "module"): + return self.module.forward_anti_outlier(x, residual) + import torch_npu if residual is not None: @@ -37,4 +28,28 @@ def forward_oot( return x +def enable_rmsnorm_with_antioutlier(): + def init( + self, + hidden_size: int, + eps: float = 1e-6, + var_hidden_size: Optional[int] = None, + has_weight: bool = True, + ) -> None: + super(RMSNorm, self).__init__() + self.hidden_size = hidden_size + self.variance_epsilon = eps + self.variance_size_override = (None if var_hidden_size == hidden_size + else var_hidden_size) + self.has_weight = has_weight + + self.weight = torch.ones(hidden_size) + if self.has_weight: + self.weight = torch.nn.Parameter(self.weight) + + self.module = RMSNormWithAntiOutlier(self.hidden_size) + + RMSNorm.__init__ = init + + RMSNorm.forward_oot = forward_oot diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 2b847de1..dfa39772 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -53,6 +53,10 @@ class NPUPlatform(Platform): ray_device_key: str = "NPU" device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES" + supported_quantization: list[str] = [ + "ascend" + ] + @classmethod def get_device_capability(cls, device_id: int = 0): return None diff --git a/vllm_ascend/quantize/__init__.py b/vllm_ascend/quantize/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vllm_ascend/quantize/quant_config.py b/vllm_ascend/quantize/quant_config.py new file mode 100644 index 00000000..22230d3b --- /dev/null +++ b/vllm_ascend/quantize/quant_config.py @@ -0,0 +1,171 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Any, Dict, List, Optional
+
+import torch
+import torch_npu
+
+from vllm.logger import init_logger
+from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, RowParallelLinear)
+from vllm.model_executor.layers.quantization import (register_quantization_config)
+from vllm.model_executor.layers.quantization.base_config import (QuantizationConfig)
+from vllm.model_executor.parameter import (BasevLLMParameter,
+                                           ChannelQuantScaleParameter,
+                                           PackedvLLMParameter,
+                                           ModelWeightParameter)
+from vllm.distributed import get_tensor_model_parallel_rank
+from .quantizer import AscendQuantizer
+
+logger = init_logger(__name__)
+
+@register_quantization_config("ascend")
+class AscendQuantConfig(QuantizationConfig):
+    """Config class for Ascend"""
+
+    def __init__(self, quant_config: Dict[str, Any]):
+        self.quant_config = quant_config
+        self.quantizer = AscendQuantizer.get_quantizer(quant_config)
+
+    def __repr__(self) -> str:
+        return "AscendQuantConfig:\n" + super().__repr__()
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "ascend"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.int8, torch.float16, torch.bfloat16]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        raise NotImplementedError("Ascend hardware does not support \"get_min_capability\" feature.")
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return []
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "AscendQuantConfig":
+        return cls(config)
+
+    @classmethod
+    def override_quantization_method(cls, hf_quant_cfg,
+                                     user_quant) -> Optional[str]:
+        dev_type = hf_quant_cfg.get("dev_type", None)
+        if dev_type == "npu":
+            return "ascend"
+        return None
+
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["QuantizeMethodBase"]:
+        if isinstance(layer, LinearBase):
+            return AscendLinearMethod(self)
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class AscendLinearMethod(LinearMethodBase):
+    """Linear method for Ascend quantization.
+
+    Args:
+        quant_config: The Ascend quantization config.
+ """ + + def __init__(self, quant_config: AscendQuantConfig) -> None: + self.quantizer = quant_config.quantizer + self.quant_method = self.quantizer.build_linear_method() + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + del output_size + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + weights = self.quant_method.create_weights( + layer, + input_size_per_partition, + output_size_per_partition, + params_dtype + ) + + weight_name = self.quant_method.get_weight() + if weight_name in weights.keys(): + layer.register_parameter( + weight_name, + ModelWeightParameter( + data=weights[weight_name].transpose(0, 1), + input_dim=1, + output_dim=0, + weight_loader=weight_loader + ) + ) + else: + raise ValueError(f"{weight_name} is nor registered. Please check your linear quant method implementation.") + + pertensor_names = self.quant_method.get_pertensor_param() + for pertensor_name in pertensor_names: + if pertensor_name in weights.keys(): + layer.register_parameter( + pertensor_name, + BasevLLMParameter( + data=weights[pertensor_name], + weight_loader=weight_loader + ) + ) + else: + raise ValueError(f"{pertensor_name} is nor registered. Please check your linear quant method implementation.") + + perchannel_names = self.quant_method.get_perchannel_param() + for perchannel_name in perchannel_names: + if perchannel_name in weights.keys(): + layer.register_parameter( + perchannel_name, + ChannelQuantScaleParameter( + data=weights[perchannel_name], + output_dim=0, + weight_loader=weight_loader + ) + ) + else: + raise ValueError(f"{perchannel_name} is nor registered. Please check your linear quant method implementation.") + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if hasattr(self.quant_method, 'transpose_weight') and self.quant_method.transpose_weight: + layer.weight.data = layer.weight.data.transpose(1, 0) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if isinstance(layer, RowParallelLinear): + tp_rank = get_tensor_model_parallel_rank() + return self.quant_method.apply(layer, x, bias, tp_rank) + return self.quant_method.apply(layer, x, bias) \ No newline at end of file diff --git a/vllm_ascend/quantize/quantizer.py b/vllm_ascend/quantize/quantizer.py new file mode 100644 index 00000000..3151f08a --- /dev/null +++ b/vllm_ascend/quantize/quantizer.py @@ -0,0 +1,52 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+from typing import Any, Dict
+
+from vllm_ascend.ops.layernorm import enable_rmsnorm_with_antioutlier
+
+CUSTOMIZED_QUANTIZER_TYPE = []
+
+class AscendQuantizer:
+    """An interface to different quantization implementations for Ascend hardware."""
+
+    @classmethod
+    def get_quantizer(cls, quant_config: Dict[str, Any]):
+        # TODO: Need a param to choose quantization algorithms.
+        quantization_algorithm = ''
+
+        if quantization_algorithm in CUSTOMIZED_QUANTIZER_TYPE:
+            return
+
+        try:
+            from mindie_turbo import MindIETurboQuantizer
+
+            if quant_config["anti_method"] is not None:
+                enable_rmsnorm_with_antioutlier()
+
+            return MindIETurboQuantizer.get_quantizer(quant_config)
+        except:
+            raise NotImplementedError("There is no available ascend quantizer.")
+
+    def build_linear_method(self):
+        raise NotImplementedError
+
+    def build_moe_method(self):
+        raise NotImplementedError
+
+    def build_attention_method(self):
+        raise NotImplementedError
\ No newline at end of file

From a4aaea41dfd7b6c8347315eeda3698663940d466 Mon Sep 17 00:00:00 2001
From: angazenn
Date: Wed, 5 Feb 2025 21:11:57 +0800
Subject: [PATCH 02/26] add license

Signed-off-by: angazenn
---
 vllm_ascend/ops/layernorm.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py
index d8beb95e..975258b7 100644
--- a/vllm_ascend/ops/layernorm.py
+++ b/vllm_ascend/ops/layernorm.py
@@ -1,3 +1,20 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
 from typing import Optional, Tuple, Union
 
 import torch

From 43b9f7076cc3da39e666571759c4afff0fdc75c4 Mon Sep 17 00:00:00 2001
From: angazenn
Date: Thu, 6 Feb 2025 15:21:29 +0800
Subject: [PATCH 03/26] fix quantization import bugs

Signed-off-by: angazenn
---
 vllm_ascend/platform.py           | 16 +++++++++++++++-
 vllm_ascend/quantize/quantizer.py |  3 ++-
 2 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index dfa39772..744ce135 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -16,7 +16,7 @@
 #
 
 import os
-from typing import Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
 
 import torch
 
@@ -27,6 +27,10 @@
 from vllm.config import VllmConfig
 from vllm.platforms import Platform, PlatformEnum
 
+if TYPE_CHECKING:
+    from vllm.utils import FlexibleArgumentParser
+else:
+    FlexibleArgumentParser = None
 
 os.environ["RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES"] = "1"
 
@@ -90,6 +94,16 @@ def synchronize(cls):
     @classmethod
     def mem_get_info(cls) -> Tuple[int, int]:
         return torch.npu.mem_get_info()
 
+    # Relies on this pull request https://github.com/vllm-project/vllm/pull/12432.
+    @classmethod
+    def pre_register_and_update(cls,
+                                parser: Optional[FlexibleArgumentParser] = None
+                                ) -> None:
+        """
+        Do some pre-registeration or update action for ascend platform.
+ """ + from vllm_ascend.quantize.quant_config import AscendQuantConfig + @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: # Register ops when setup. diff --git a/vllm_ascend/quantize/quantizer.py b/vllm_ascend/quantize/quantizer.py index 3151f08a..977b0b19 100644 --- a/vllm_ascend/quantize/quantizer.py +++ b/vllm_ascend/quantize/quantizer.py @@ -35,7 +35,8 @@ def get_quantizer(cls, quant_config: Dict[str, Any]): try: from mindie_turbo import MindIETurboQuantizer - if quant_config["anti_method"] is not None: + # When not using anti-outlier algorithms, "anti_method" refers to an empty string. + if len(quant_config["anti_method"]) > 0: enable_rmsnorm_with_antioutlier() return MindIETurboQuantizer.get_quantizer(quant_config) From 46b7ca242f47182f02e0a5705faa773839989faf Mon Sep 17 00:00:00 2001 From: angazenn Date: Thu, 6 Feb 2025 17:13:07 +0800 Subject: [PATCH 04/26] support skipping unquantized layers Signed-off-by: angazenn --- vllm_ascend/quantize/quant_config.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/vllm_ascend/quantize/quant_config.py b/vllm_ascend/quantize/quant_config.py index 22230d3b..122b6cb4 100644 --- a/vllm_ascend/quantize/quant_config.py +++ b/vllm_ascend/quantize/quant_config.py @@ -75,9 +75,37 @@ def override_quantization_method(cls, hf_quant_cfg, def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]: if isinstance(layer, LinearBase): + if self.is_layer_skipped_ascend(prefix): + return UnquantizedLinearMethod() return AscendLinearMethod(self) return None + def is_layer_skipped_ascend(self, prefix: str): + # adapted from vllm.model_executor.layers.quantization.utils.quant_utils.is_layer_skipped + proj_name = prefix.split(".")[-1] + if proj_name in FUSED_LAYER_NAME_MAPPING: + shard_prefixes = [ + prefix.replace(proj_name, shard_proj_name) + for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name] + ] + + is_skipped = None + for shard_prefix in shard_prefixes: + is_shard_skipped = self.quant_description[shard_prefix + '.weight'] == "FLOAT" + + if is_skipped is None: + is_skipped = is_shard_skipped + elif is_shard_skipped != is_skipped: + raise ValueError( + f"Detected some but not all shards of {prefix} " + "are quantized. 
All shards of fused layers "
+                    "must have the same precision.")
+        else:
+            is_skipped = self.quant_description[prefix + '.weight'] == "FLOAT"
+
+        assert is_skipped is not None
+        return is_skipped
+
     def get_scaled_act_names(self) -> List[str]:
         return []

From 5b1b34db92dc6f092e59bc51cbbdbf9ef48e1ffb Mon Sep 17 00:00:00 2001
From: angazenn
Date: Thu, 6 Feb 2025 17:19:08 +0800
Subject: [PATCH 05/26] fix ci

Signed-off-by: angazenn
---
 vllm_ascend/quantize/quant_config.py | 3 +--
 vllm_ascend/quantize/quantizer.py    | 4 ++--
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/vllm_ascend/quantize/quant_config.py b/vllm_ascend/quantize/quant_config.py
index 122b6cb4..a69ef01d 100644
--- a/vllm_ascend/quantize/quant_config.py
+++ b/vllm_ascend/quantize/quant_config.py
@@ -23,10 +23,9 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, RowParallelLinear)
 from vllm.model_executor.layers.quantization import (register_quantization_config)
-from vllm.model_executor.layers.quantization.base_config import (QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import (QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            ChannelQuantScaleParameter,
-                                           PackedvLLMParameter,
                                            ModelWeightParameter)
 from vllm.distributed import get_tensor_model_parallel_rank
 from .quantizer import AscendQuantizer

diff --git a/vllm_ascend/quantize/quantizer.py b/vllm_ascend/quantize/quantizer.py
index 977b0b19..8c9ce60f 100644
--- a/vllm_ascend/quantize/quantizer.py
+++ b/vllm_ascend/quantize/quantizer.py
@@ -15,11 +15,11 @@
 # limitations under the License.
 #
 
-from typing import Any, Dict
+from typing import Any, Dict, List
 
 from vllm_ascend.ops.layernorm import enable_rmsnorm_with_antioutlier
 
-CUSTOMIZED_QUANTIZER_TYPE = []
+CUSTOMIZED_QUANTIZER_TYPE: List[str] = []
 
 class AscendQuantizer:
     """An interface to different quantization implementations for Ascend hardware."""

From 7a41f8f0184c74598a61d5e135c40787ead5069e Mon Sep 17 00:00:00 2001
From: angazenn
Date: Thu, 6 Feb 2025 19:11:57 +0800
Subject: [PATCH 06/26] remove unnecessary params

Signed-off-by: angazenn
---
 vllm_ascend/quantize/quant_config.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/vllm_ascend/quantize/quant_config.py b/vllm_ascend/quantize/quant_config.py
index a69ef01d..e12f5657 100644
--- a/vllm_ascend/quantize/quant_config.py
+++ b/vllm_ascend/quantize/quant_config.py
@@ -135,7 +135,6 @@ def create_weights(
         weight_loader = extra_weight_attrs.get("weight_loader")
 
         weights = self.quant_method.create_weights(
-            layer,
             input_size_per_partition,
             output_size_per_partition,
             params_dtype
         )

From d3323515f4b069ea0a4e920285cf6e47d89d584f Mon Sep 17 00:00:00 2001
From: angazenn
Date: Fri, 7 Feb 2025 10:01:23 +0800
Subject: [PATCH 07/26] fix import errors

Signed-off-by: angazenn
---
 vllm_ascend/quantize/quant_config.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/vllm_ascend/quantize/quant_config.py b/vllm_ascend/quantize/quant_config.py
index e12f5657..edb2e172 100644
--- a/vllm_ascend/quantize/quant_config.py
+++ b/vllm_ascend/quantize/quant_config.py
@@ -18,15 +18,17 @@
 from typing import Any, Dict, List, Optional
 
 import torch
-import torch_npu
 
 from vllm.logger import init_logger
-from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, RowParallelLinear)
+from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               RowParallelLinear, UnquantizedLinearMethod)
 from
vllm.model_executor.layers.quantization import (register_quantization_config) from vllm.model_executor.layers.quantization.base_config import (QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.parameter import (BasevLLMParameter, ChannelQuantScaleParameter, ModelWeightParameter) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + FUSED_LAYER_NAME_MAPPING) from vllm.distributed import get_tensor_model_parallel_rank from .quantizer import AscendQuantizer From 37c4543d2dadb5a995cb8a92e714a488f0ae0c69 Mon Sep 17 00:00:00 2001 From: angazenn Date: Fri, 7 Feb 2025 15:44:30 +0800 Subject: [PATCH 08/26] avoid import check Signed-off-by: angazenn --- vllm_ascend/platform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 744ce135..42861d1e 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -102,7 +102,7 @@ def pre_register_and_update(cls, """ Do some pre-registeration or update action for ascend platform. """ - from vllm_ascend.quantize.quant_config import AscendQuantConfig + from vllm_ascend.quantize.quant_config import AscendQuantConfig # noqa: F401 @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: From 7e230f04e72f0fbe14bf7e01480db617ec802993 Mon Sep 17 00:00:00 2001 From: angazenn Date: Fri, 7 Feb 2025 17:18:34 +0800 Subject: [PATCH 09/26] add ascend quantization ut Signed-off-by: angazenn --- {vllm_ascend/quantize => tests}/__init__.py | 0 tests/quantization/__init__.py | 0 tests/quantization/test_mindie_turbo.py | 58 +++++++++++++++++++ tests/quantization/utils.py | 26 +++++++++ vllm_ascend/platform.py | 2 +- vllm_ascend/quantization/__init__.py | 0 .../quant_config.py | 0 .../{quantize => quantization}/quantizer.py | 0 8 files changed, 85 insertions(+), 1 deletion(-) rename {vllm_ascend/quantize => tests}/__init__.py (100%) create mode 100644 tests/quantization/__init__.py create mode 100644 tests/quantization/test_mindie_turbo.py create mode 100644 tests/quantization/utils.py create mode 100644 vllm_ascend/quantization/__init__.py rename vllm_ascend/{quantize => quantization}/quant_config.py (100%) rename vllm_ascend/{quantize => quantization}/quantizer.py (100%) diff --git a/vllm_ascend/quantize/__init__.py b/tests/__init__.py similarity index 100% rename from vllm_ascend/quantize/__init__.py rename to tests/__init__.py diff --git a/tests/quantization/__init__.py b/tests/quantization/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/quantization/test_mindie_turbo.py b/tests/quantization/test_mindie_turbo.py new file mode 100644 index 00000000..e3267698 --- /dev/null +++ b/tests/quantization/test_mindie_turbo.py @@ -0,0 +1,58 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +"""Tests whether ascend quantization based on MindIE-Turbo is enabled correctly. + +Run `pytest tests/quantization/test_mindie_turbo.py`. +""" + +import pytest + +from .utils import is_mindie_turbo_supported + +MODELS = [ + "LLaMA3-8B_W8A8/", +] + + +@pytest.mark.skipif(not is_mindie_turbo_supported(), + reason="MindIE-Turbo is not installed.") +@pytest.mark.parametrize("model_name_or_path", MODELS) +@pytest.mark.parametrize("max_tokens", [5]) +def test_mindie_turbo( + model_name_or_path: str, + max_tokens: int, +) -> None: + + import vllm # noqa: F401 + from ..conftest import VllmRunner + + import vllm_ascend # noqa: F401 + from vllm_ascend.quantize.quant_config import AscendLinearMethod + + prompt = "What's deep learning?" + example_prompts = [prompt] + + with VllmRunner(model_name_or_path, + max_model_len=8192, + dtype="bfloat16", + enforce_eager=False, + gpu_memory_utilization=0.7) as vllm_model: + + output = vllm_model.generate_greedy(example_prompts, max_tokens) + assert output \ No newline at end of file diff --git a/tests/quantization/utils.py b/tests/quantization/utils.py new file mode 100644 index 00000000..12756bc5 --- /dev/null +++ b/tests/quantization/utils.py @@ -0,0 +1,26 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +def is_mindie_turbo_supported() -> bool: + try: + import mindie_turbo + except: + return False + + return True \ No newline at end of file diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 42861d1e..405e0483 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -102,7 +102,7 @@ def pre_register_and_update(cls, """ Do some pre-registeration or update action for ascend platform. 
""" - from vllm_ascend.quantize.quant_config import AscendQuantConfig # noqa: F401 + from vllm_ascend.quantization.quant_config import AscendQuantConfig # noqa: F401 @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: diff --git a/vllm_ascend/quantization/__init__.py b/vllm_ascend/quantization/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vllm_ascend/quantize/quant_config.py b/vllm_ascend/quantization/quant_config.py similarity index 100% rename from vllm_ascend/quantize/quant_config.py rename to vllm_ascend/quantization/quant_config.py diff --git a/vllm_ascend/quantize/quantizer.py b/vllm_ascend/quantization/quantizer.py similarity index 100% rename from vllm_ascend/quantize/quantizer.py rename to vllm_ascend/quantization/quantizer.py From 1bfb206e03f93f99c18b1b91f0481cb4b97d4d7a Mon Sep 17 00:00:00 2001 From: angazenn Date: Fri, 7 Feb 2025 17:59:59 +0800 Subject: [PATCH 10/26] fix quant description bugs Signed-off-by: angazenn --- vllm_ascend/quantization/quant_config.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index edb2e172..b9a2f395 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -20,7 +20,7 @@ import torch from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, RowParallelLinear, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization import (register_quantization_config) from vllm.model_executor.layers.quantization.base_config import (QuantizationConfig, QuantizeMethodBase) @@ -39,6 +39,7 @@ class AscendQuantConfig(QuantizationConfig): """Config class for Ascend""" def __init__(self, quant_config: Dict[str, Any]): + self.quant_description = quant_config.pop("quant_description") self.quant_config = quant_config self.quantizer = AscendQuantizer.get_quantizer(quant_config) From 9bbc77c2d838a1cd18860b2b706c4da4cae95036 Mon Sep 17 00:00:00 2001 From: angazenn Date: Fri, 7 Feb 2025 18:06:16 +0800 Subject: [PATCH 11/26] fix ut bugs Signed-off-by: angazenn --- tests/quantization/test_mindie_turbo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/quantization/test_mindie_turbo.py b/tests/quantization/test_mindie_turbo.py index e3267698..11dee9aa 100644 --- a/tests/quantization/test_mindie_turbo.py +++ b/tests/quantization/test_mindie_turbo.py @@ -43,7 +43,7 @@ def test_mindie_turbo( from ..conftest import VllmRunner import vllm_ascend # noqa: F401 - from vllm_ascend.quantize.quant_config import AscendLinearMethod + from vllm_ascend.quantization.quant_config import AscendLinearMethod prompt = "What's deep learning?" 
example_prompts = [prompt] From b5d4bf690711e6f0f27c8cbf2128bb5508ea72f2 Mon Sep 17 00:00:00 2001 From: angazenn Date: Fri, 7 Feb 2025 20:05:54 +0800 Subject: [PATCH 12/26] move quantizer initialization to linear method Signed-off-by: angazenn --- vllm_ascend/quantization/quant_config.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index b9a2f395..bd5fcbca 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -41,7 +41,6 @@ class AscendQuantConfig(QuantizationConfig): def __init__(self, quant_config: Dict[str, Any]): self.quant_description = quant_config.pop("quant_description") self.quant_config = quant_config - self.quantizer = AscendQuantizer.get_quantizer(quant_config) def __repr__(self) -> str: return "AscendQuantConfig:\n" + super().__repr__() @@ -120,7 +119,7 @@ class AscendLinearMethod(LinearMethodBase): """ def __init__(self, quant_config: AscendQuantConfig) -> None: - self.quantizer = quant_config.quantizer + self.quantizer = AscendQuantizer.get_quantizer(quant_config.quant_config) self.quant_method = self.quantizer.build_linear_method() def create_weights( From d3dd745d90cad258dbf2cc60bd66afcdab44b7d9 Mon Sep 17 00:00:00 2001 From: angazenn Date: Fri, 7 Feb 2025 20:14:18 +0800 Subject: [PATCH 13/26] fix bugs Signed-off-by: angazenn --- vllm_ascend/ops/layernorm.py | 2 +- vllm_ascend/quantization/quantizer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py index 975258b7..2a27224d 100644 --- a/vllm_ascend/ops/layernorm.py +++ b/vllm_ascend/ops/layernorm.py @@ -22,7 +22,7 @@ try: from mindie_turbo import RMSNormWithAntiOutlier -except: +except Exception: pass diff --git a/vllm_ascend/quantization/quantizer.py b/vllm_ascend/quantization/quantizer.py index 8c9ce60f..43b3fc76 100644 --- a/vllm_ascend/quantization/quantizer.py +++ b/vllm_ascend/quantization/quantizer.py @@ -40,7 +40,7 @@ def get_quantizer(cls, quant_config: Dict[str, Any]): enable_rmsnorm_with_antioutlier() return MindIETurboQuantizer.get_quantizer(quant_config) - except: + except Exception: raise NotImplementedError("There is no available ascend quantizer.") def build_linear_method(self): From b809659579fe876b358bd159364ab4b1886211ae Mon Sep 17 00:00:00 2001 From: angazenn Date: Sat, 8 Feb 2025 10:43:41 +0800 Subject: [PATCH 14/26] narrow down the codes that try-catch structure covers when importing quantizer Signed-off-by: angazenn --- vllm_ascend/quantization/quantizer.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/vllm_ascend/quantization/quantizer.py b/vllm_ascend/quantization/quantizer.py index 43b3fc76..fb6a2a4d 100644 --- a/vllm_ascend/quantization/quantizer.py +++ b/vllm_ascend/quantization/quantizer.py @@ -15,6 +15,7 @@ # limitations under the License. # +import importlib from typing import Any, Dict, List from vllm_ascend.ops.layernorm import enable_rmsnorm_with_antioutlier @@ -33,16 +34,16 @@ def get_quantizer(cls, quant_config: Dict[str, Any]): return try: - from mindie_turbo import MindIETurboQuantizer - - # When not using anti-outlier algorithms, "anti_method" refers to an empty string. 
- if len(quant_config["anti_method"]) > 0: - enable_rmsnorm_with_antioutlier() - - return MindIETurboQuantizer.get_quantizer(quant_config) + importlib.import_module("mindie_turbo.MindIETurboQuantizer") except Exception: raise NotImplementedError("There is no available ascend quantizer.") + # When not using anti-outlier algorithms, "anti_method" refers to an empty string. + if len(quant_config["anti_method"]) > 0: + enable_rmsnorm_with_antioutlier() + + return MindIETurboQuantizer.get_quantizer(quant_config) + def build_linear_method(self): raise NotImplementedError From 7f4b41c4883da47d8e57b7f396759650925bb395 Mon Sep 17 00:00:00 2001 From: angazenn Date: Sat, 8 Feb 2025 10:46:52 +0800 Subject: [PATCH 15/26] move packages imported to the head Signed-off-by: angazenn --- tests/quantization/test_mindie_turbo.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/quantization/test_mindie_turbo.py b/tests/quantization/test_mindie_turbo.py index 11dee9aa..31117461 100644 --- a/tests/quantization/test_mindie_turbo.py +++ b/tests/quantization/test_mindie_turbo.py @@ -23,6 +23,12 @@ import pytest +import vllm # noqa: F401 + +import vllm_ascend # noqa: F401 +from vllm_ascend.quantization.quant_config import AscendLinearMethod + +from ..conftest import VllmRunner from .utils import is_mindie_turbo_supported MODELS = [ @@ -39,12 +45,6 @@ def test_mindie_turbo( max_tokens: int, ) -> None: - import vllm # noqa: F401 - from ..conftest import VllmRunner - - import vllm_ascend # noqa: F401 - from vllm_ascend.quantization.quant_config import AscendLinearMethod - prompt = "What's deep learning?" example_prompts = [prompt] From a55390681df74a40933b73f6e8c3021307dfa85e Mon Sep 17 00:00:00 2001 From: angazenn Date: Sat, 8 Feb 2025 10:55:28 +0800 Subject: [PATCH 16/26] fix import bugs Signed-off-by: angazenn --- tests/quantization/test_mindie_turbo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/quantization/test_mindie_turbo.py b/tests/quantization/test_mindie_turbo.py index 31117461..fb0b0dec 100644 --- a/tests/quantization/test_mindie_turbo.py +++ b/tests/quantization/test_mindie_turbo.py @@ -28,8 +28,8 @@ import vllm_ascend # noqa: F401 from vllm_ascend.quantization.quant_config import AscendLinearMethod -from ..conftest import VllmRunner -from .utils import is_mindie_turbo_supported +from tests.conftest import VllmRunner +from tests.quantization.utils import is_mindie_turbo_supported MODELS = [ "LLaMA3-8B_W8A8/", From 3b94a9f799bcc8787e95f91559fc773557916947 Mon Sep 17 00:00:00 2001 From: angazenn Date: Sat, 8 Feb 2025 11:41:26 +0800 Subject: [PATCH 17/26] fix import bugs Signed-off-by: angazenn --- vllm_ascend/quantization/quantizer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm_ascend/quantization/quantizer.py b/vllm_ascend/quantization/quantizer.py index fb6a2a4d..dba57bc6 100644 --- a/vllm_ascend/quantization/quantizer.py +++ b/vllm_ascend/quantization/quantizer.py @@ -34,7 +34,8 @@ def get_quantizer(cls, quant_config: Dict[str, Any]): return try: - importlib.import_module("mindie_turbo.MindIETurboQuantizer") + module = importlib.import_module("mindie_turbo") + MindIETurboQuantizer = module.MindIETurboQuantizer except Exception: raise NotImplementedError("There is no available ascend quantizer.") From 210e6dc90bb275376ca64c6bb19c630916b314e7 Mon Sep 17 00:00:00 2001 From: angazenn Date: Sat, 8 Feb 2025 17:39:36 +0800 Subject: [PATCH 18/26] remove anti-outlier conditions from vllm_ascend Signed-off-by: angazenn 
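
Note: after this change vllm-ascend no longer inspects "anti_method" itself;
applying the anti-outlier RMSNorm rewrite is left entirely to MindIE-Turbo
when the quantizer is created. A minimal sketch of the dispatch that remains
(the config values below are illustrative only, not part of this patch):

    from vllm_ascend.quantization.quantizer import AscendQuantizer

    # "anti_method" still travels inside the quant config, but only
    # MindIE-Turbo interprets it from now on.
    quantizer = AscendQuantizer.get_quantizer({"anti_method": "m2"})
    linear_method = quantizer.build_linear_method()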
--- vllm_ascend/ops/layernorm.py | 31 --------------------------- vllm_ascend/quantization/quantizer.py | 6 ------ 2 files changed, 37 deletions(-) diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py index 2a27224d..fced707d 100644 --- a/vllm_ascend/ops/layernorm.py +++ b/vllm_ascend/ops/layernorm.py @@ -20,19 +20,12 @@ import torch from vllm.model_executor.layers.layernorm import RMSNorm -try: - from mindie_turbo import RMSNormWithAntiOutlier -except Exception: - pass - def forward_oot( self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - if hasattr(self, "module"): - return self.module.forward_anti_outlier(x, residual) import torch_npu @@ -45,28 +38,4 @@ def forward_oot( return x -def enable_rmsnorm_with_antioutlier(): - def init( - self, - hidden_size: int, - eps: float = 1e-6, - var_hidden_size: Optional[int] = None, - has_weight: bool = True, - ) -> None: - super(RMSNorm, self).__init__() - self.hidden_size = hidden_size - self.variance_epsilon = eps - self.variance_size_override = (None if var_hidden_size == hidden_size - else var_hidden_size) - self.has_weight = has_weight - - self.weight = torch.ones(hidden_size) - if self.has_weight: - self.weight = torch.nn.Parameter(self.weight) - - self.module = RMSNormWithAntiOutlier(self.hidden_size) - - RMSNorm.__init__ = init - - RMSNorm.forward_oot = forward_oot diff --git a/vllm_ascend/quantization/quantizer.py b/vllm_ascend/quantization/quantizer.py index dba57bc6..b6e2318b 100644 --- a/vllm_ascend/quantization/quantizer.py +++ b/vllm_ascend/quantization/quantizer.py @@ -18,8 +18,6 @@ import importlib from typing import Any, Dict, List -from vllm_ascend.ops.layernorm import enable_rmsnorm_with_antioutlier - CUSTOMIZED_QUANTIZER_TYPE: List[str] = [] class AscendQuantizer: @@ -39,10 +37,6 @@ def get_quantizer(cls, quant_config: Dict[str, Any]): except Exception: raise NotImplementedError("There is no available ascend quantizer.") - # When not using anti-outlier algorithms, "anti_method" refers to an empty string. 
- if len(quant_config["anti_method"]) > 0: - enable_rmsnorm_with_antioutlier() - return MindIETurboQuantizer.get_quantizer(quant_config) def build_linear_method(self): From 639b6022dfe5aabdca91be16ba8e2143e97ea8bb Mon Sep 17 00:00:00 2001 From: angazenn Date: Sat, 8 Feb 2025 17:49:49 +0800 Subject: [PATCH 19/26] remove unnecessary spaces Signed-off-by: angazenn --- vllm_ascend/ops/layernorm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py index fced707d..719aa977 100644 --- a/vllm_ascend/ops/layernorm.py +++ b/vllm_ascend/ops/layernorm.py @@ -26,7 +26,6 @@ def forward_oot( x: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - import torch_npu if residual is not None: From a92d9fe96345c178c2ea5e385fdf0312dd3f0d6b Mon Sep 17 00:00:00 2001 From: angazenn Date: Sat, 8 Feb 2025 20:31:03 +0800 Subject: [PATCH 20/26] add example quantization codes to ut Signed-off-by: angazenn --- tests/quantization/test_mindie_turbo.py | 14 +++- tests/quantization/utils.py | 93 ++++++++++++++++++++++++- 2 files changed, 103 insertions(+), 4 deletions(-) diff --git a/tests/quantization/test_mindie_turbo.py b/tests/quantization/test_mindie_turbo.py index fb0b0dec..1fb9c3b5 100644 --- a/tests/quantization/test_mindie_turbo.py +++ b/tests/quantization/test_mindie_turbo.py @@ -21,6 +21,8 @@ Run `pytest tests/quantization/test_mindie_turbo.py`. """ +import os + import pytest import vllm # noqa: F401 @@ -29,10 +31,10 @@ from vllm_ascend.quantization.quant_config import AscendLinearMethod from tests.conftest import VllmRunner -from tests.quantization.utils import is_mindie_turbo_supported +from tests.quantization.utils import is_mindie_turbo_supported, example_quantization MODELS = [ - "LLaMA3-8B_W8A8/", + "/home/zyj/data/Qwen2.5-0.5B-Instruct/", ] @@ -44,11 +46,17 @@ def test_mindie_turbo( model_name_or_path: str, max_tokens: int, ) -> None: + # vLLM should load weights from disk. Hence we need to save the quantized + # weights at first, and then load it by vLLM. + temp_path = os.path.join(os.path.dirname(__file__), "temp_weight") + if not os.path.exists(temp_path): + os.makedirs(temp_path) + example_quantization(model_name_or_path, temp_path) prompt = "What's deep learning?" example_prompts = [prompt] - with VllmRunner(model_name_or_path, + with VllmRunner(temp_path, max_model_len=8192, dtype="bfloat16", enforce_eager=False, diff --git a/tests/quantization/utils.py b/tests/quantization/utils.py index 12756bc5..e9d71dcd 100644 --- a/tests/quantization/utils.py +++ b/tests/quantization/utils.py @@ -17,10 +17,101 @@ # limitations under the License. 
 #
+import os
+import shutil
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from msmodelslim.pytorch.llm_ptq.anti_outlier import AntiOutlierConfig, AntiOutlier
+from msmodelslim.pytorch.llm_ptq.llm_ptq_tools import Calibrator, QuantConfig
+
+
 def is_mindie_turbo_supported() -> bool:
     try:
         import mindie_turbo
     except:
         return False
 
-    return True
\ No newline at end of file
+    return True
+
+
+def example_quantization(model_name_or_path: str, tmp_path: str) -> None:
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        pretrained_model_name_or_path=model_name_or_path
+    )
+
+    model = AutoModelForCausalLM.from_pretrained(
+        pretrained_model_name_or_path=model_name_or_path,
+        device_map="npu:0",
+        torch_dtype="auto"
+    ).eval()
+
+    data_list = ["What's deep learning?"]
+    dataset_calib = []
+    for calib_data in data_list:
+        inputs = tokenizer(calib_data, return_tensors='pt').to("npu:0")
+        dataset_calib.append([inputs.data['input_ids']])
+
+    anti_config = AntiOutlierConfig(anti_method="m2", dev_type="npu", dev_id=0)
+    anti_outlier = AntiOutlier(model, calib_data=dataset_calib, cfg=anti_config)
+    anti_outlier.process()
+
+    disable_names = ['lm_head']
+    for layer_index in range(24):
+        disable_names.append(f'model.layers.{layer_index}.mlp.down_proj')
+
+    quant_config = QuantConfig(
+        a_bit=8,
+        w_bit=8,
+        disable_names=disable_names,
+        dev_type='npu',
+        dev_id=0,
+        act_method=3,
+        pr=1.0,
+        w_sym=True,
+        mm_tensor=False
+    )
+
+    calibrator = Calibrator(model, quant_config, calib_data=dataset_calib, disable_level='L0')
+    calibrator.run()
+
+    # Currently, we need to add config.json manually for quantized weights generated by msmodelslim.
+    # The following code will be removed once msmodelslim can generate complete weights,
+    # except for 'calibrator.save(tmp_path, save_type=["safe_tensor"])'.
+ class EmptyModule(torch.nn.Module): + def __init__(self) -> None: + super(EmptyModule, self).__init__() + + def forward(self, x): + return x + + test_quant_config = { + "group_size": 0, + "kv_quant_type": None, + "fa_quant_type": None, + "w_bit": 8, + "a_bit": 8, + "dev_type": "npu", + "fraction": 0.01, + "act_method": 3, + "co_sparse": False, + "anti_method": "m2", + "disable_level": "L0", + "do_smooth": False, + "use_sigma": False, + "sigma_factor": 3.0, + "is_lowbit": False, + "mm_tensor": False, + "w_sym": True, + "open_outlier": True, + "is_dynamic": False + } + + calibrator.model.config.update({"quantization_config": test_quant_config}) + calibrator.model.config.quantization_config.update( + {"quant_description": calibrator.quant_model_json_description.quant_model_description} + ) + + calibrator.save(tmp_path, save_type=["safe_tensor"]) + calibrator.model.save_pretrained(tmp_path, state_dict=EmptyModule().state_dict()) + tokenizer.save_pretrained(tmp_path) \ No newline at end of file From 44737f9c71518bfbfcca53c89fbee6ca9559e6b8 Mon Sep 17 00:00:00 2001 From: angazenn Date: Mon, 10 Feb 2025 13:26:01 +0800 Subject: [PATCH 21/26] fix bugs Signed-off-by: angazenn --- tests/quantization/test_mindie_turbo.py | 4 ++-- tests/quantization/utils.py | 2 -- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/quantization/test_mindie_turbo.py b/tests/quantization/test_mindie_turbo.py index 1fb9c3b5..700a65ff 100644 --- a/tests/quantization/test_mindie_turbo.py +++ b/tests/quantization/test_mindie_turbo.py @@ -34,7 +34,7 @@ from tests.quantization.utils import is_mindie_turbo_supported, example_quantization MODELS = [ - "/home/zyj/data/Qwen2.5-0.5B-Instruct/", + "Qwen/Qwen2.5-0.5B-Instruct", ] @@ -46,7 +46,7 @@ def test_mindie_turbo( model_name_or_path: str, max_tokens: int, ) -> None: - # vLLM should load weights from disk. Hence we need to save the quantized + # vLLM must load weights from disk. Hence we need to save the quantized # weights at first, and then load it by vLLM. temp_path = os.path.join(os.path.dirname(__file__), "temp_weight") if not os.path.exists(temp_path): diff --git a/tests/quantization/utils.py b/tests/quantization/utils.py index e9d71dcd..657392d7 100644 --- a/tests/quantization/utils.py +++ b/tests/quantization/utils.py @@ -1,8 +1,6 @@ # # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. # This file is a part of the vllm-ascend project. -# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py -# Copyright 2023 The vLLM team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From f1e955673729db89175f5c15336545eb6bfa4788 Mon Sep 17 00:00:00 2001 From: angazenn Date: Mon, 10 Feb 2025 15:58:12 +0800 Subject: [PATCH 22/26] move importation of AscendQuantConfig Signed-off-by: angazenn --- vllm_ascend/__init__.py | 4 ++++ vllm_ascend/platform.py | 10 ---------- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/vllm_ascend/__init__.py b/vllm_ascend/__init__.py index 80af5a52..7d6c1784 100644 --- a/vllm_ascend/__init__.py +++ b/vllm_ascend/__init__.py @@ -18,4 +18,8 @@ def register(): """Register the NPU platform.""" + # TODO: https://github.com/vllm-project/vllm/pull/12432 Once this pr is merged, + # the following module can be imported using pre_register_and_update function. 
+ from vllm_ascend.quantization.quant_config import AscendQuantConfig # noqa: F401 + return "vllm_ascend.platform.NPUPlatform" diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 405e0483..510e1d22 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -94,16 +94,6 @@ def synchronize(cls): def mem_get_info(cls) -> Tuple[int, int]: return torch.npu.mem_get_info() - # Relies on this pull request https://github.com/vllm-project/vllm/pull/12432. - @classmethod - def pre_register_and_update(cls, - parser: Optional[FlexibleArgumentParser] = None - ) -> None: - """ - Do some pre-registeration or update action for ascend platform. - """ - from vllm_ascend.quantization.quant_config import AscendQuantConfig # noqa: F401 - @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: # Register ops when setup. From 487dc2b3586bccb52fa25908a419a34dd1400543 Mon Sep 17 00:00:00 2001 From: angazenn Date: Tue, 11 Feb 2025 10:58:40 +0800 Subject: [PATCH 23/26] adapt to new config.json Signed-off-by: angazenn --- tests/quantization/utils.py | 27 +----------------------- vllm_ascend/quantization/quant_config.py | 10 ++++----- 2 files changed, 6 insertions(+), 31 deletions(-) diff --git a/tests/quantization/utils.py b/tests/quantization/utils.py index 657392d7..fa453b54 100644 --- a/tests/quantization/utils.py +++ b/tests/quantization/utils.py @@ -83,32 +83,7 @@ def __init__(self) -> None: def forward(self, x): return x - test_quant_config = { - "group_size": 0, - "kv_quant_type": None, - "fa_quant_type": None, - "w_bit": 8, - "a_bit": 8, - "dev_type": "npu", - "fraction": 0.01, - "act_method": 3, - "co_sparse": False, - "anti_method": "m2", - "disable_level": "L0", - "do_smooth": False, - "use_sigma": False, - "sigma_factor": 3.0, - "is_lowbit": False, - "mm_tensor": False, - "w_sym": True, - "open_outlier": True, - "is_dynamic": False - } - - calibrator.model.config.update({"quantization_config": test_quant_config}) - calibrator.model.config.quantization_config.update( - {"quant_description": calibrator.quant_model_json_description.quant_model_description} - ) + calibrator.model.config.quantization_config = calibrator.quant_model_json_description.quant_model_description calibrator.save(tmp_path, save_type=["safe_tensor"]) calibrator.model.save_pretrained(tmp_path, state_dict=EmptyModule().state_dict()) diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index bd5fcbca..144cfaef 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -1,6 +1,7 @@ # # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. # This file is a part of the vllm-ascend project. +# Copyright 2023 The vLLM team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -18,6 +19,7 @@ from typing import Any, Dict, List, Optional import torch +import torch_npu from vllm.logger import init_logger from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, @@ -39,8 +41,7 @@ class AscendQuantConfig(QuantizationConfig): """Config class for Ascend""" def __init__(self, quant_config: Dict[str, Any]): - self.quant_description = quant_config.pop("quant_description") - self.quant_config = quant_config + self.quant_description = quant_config def __repr__(self) -> str: return "AscendQuantConfig:\n" + super().__repr__() @@ -68,8 +69,7 @@ def from_config(cls, config: Dict[str, Any]) -> "AscendQuantConfig": @classmethod def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: - dev_type = hf_quant_cfg.get("dev_type", None) - if dev_type == "npu": + if torch.npu.is_available(): return "ascend" return None @@ -119,7 +119,7 @@ class AscendLinearMethod(LinearMethodBase): """ def __init__(self, quant_config: AscendQuantConfig) -> None: - self.quantizer = AscendQuantizer.get_quantizer(quant_config.quant_config) + self.quantizer = AscendQuantizer.get_quantizer(quant_config.quant_description) self.quant_method = self.quantizer.build_linear_method() def create_weights( From 90531d949a3be810c0b1972a292c9421a26c6bbb Mon Sep 17 00:00:00 2001 From: angazenn Date: Tue, 11 Feb 2025 11:46:02 +0800 Subject: [PATCH 24/26] move import of AsecndQuantConfig to pre_register Signed-off-by: angazenn --- vllm_ascend/__init__.py | 4 ---- vllm_ascend/platform.py | 15 +++++++++++++++ 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/vllm_ascend/__init__.py b/vllm_ascend/__init__.py index 7d6c1784..80af5a52 100644 --- a/vllm_ascend/__init__.py +++ b/vllm_ascend/__init__.py @@ -18,8 +18,4 @@ def register(): """Register the NPU platform.""" - # TODO: https://github.com/vllm-project/vllm/pull/12432 Once this pr is merged, - # the following module can be imported using pre_register_and_update function. - from vllm_ascend.quantization.quant_config import AscendQuantConfig # noqa: F401 - return "vllm_ascend.platform.NPUPlatform" diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 510e1d22..dd15537c 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -94,6 +94,21 @@ def synchronize(cls): def mem_get_info(cls) -> Tuple[int, int]: return torch.npu.mem_get_info() + @classmethod + def pre_register_and_update(cls, + parser: Optional[FlexibleArgumentParser] = None + ) -> None: + """ + Do some pre-registeration or update action for the current platform. + This function is called before global VllmConfig is initialized or cli + arguments are parsed. It's used for out-of-tree platforms to register or + update the configuration. + For example, the out-of-tree quantization config can be imported and + registered here dynamically. + """ + + from vllm_ascend.quantization.quant_config import AscendQuantConfig # noqa: F401 + @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: # Register ops when setup. 
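
Note on the hook above: pre_register_and_update works purely by import side
effect. quant_config.py decorates AscendQuantConfig with
@register_quantization_config("ascend"), so merely importing the module adds
"ascend" to vLLM's set of known quantization methods. A simplified sketch of
that mechanism (illustrative only; the real registry lives in
vllm.model_executor.layers.quantization):

    _QUANT_CONFIGS = {}  # assumed shape of the registry, keyed by method name

    def register_quantization_config(name):
        def decorator(cls):
            # Runs when the decorated module is imported, which is why a
            # bare import inside pre_register_and_update is sufficient.
            _QUANT_CONFIGS[name] = cls
            return cls
        return decorator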
From 2c38b410ad5fe2c506d9c5979f519c477b3fb71c Mon Sep 17 00:00:00 2001
From: angazenn
Date: Tue, 11 Feb 2025 14:13:39 +0800
Subject: [PATCH 25/26] clean code

Signed-off-by: angazenn
---
 tests/quantization/test_mindie_turbo.py  | 1 -
 tests/quantization/utils.py              | 6 ++----
 vllm_ascend/quantization/quant_config.py | 2 +-
 3 files changed, 3 insertions(+), 6 deletions(-)

diff --git a/tests/quantization/test_mindie_turbo.py b/tests/quantization/test_mindie_turbo.py
index 700a65ff..8e48735b 100644
--- a/tests/quantization/test_mindie_turbo.py
+++ b/tests/quantization/test_mindie_turbo.py
@@ -28,7 +28,6 @@
 import vllm  # noqa: F401
 
 import vllm_ascend  # noqa: F401
-from vllm_ascend.quantization.quant_config import AscendLinearMethod
 
 from tests.conftest import VllmRunner
 from tests.quantization.utils import is_mindie_turbo_supported, example_quantization

diff --git a/tests/quantization/utils.py b/tests/quantization/utils.py
index fa453b54..40ccfa6b 100644
--- a/tests/quantization/utils.py
+++ b/tests/quantization/utils.py
@@ -15,8 +15,6 @@
 # limitations under the License.
 #
 
-import os
-import shutil
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from msmodelslim.pytorch.llm_ptq.anti_outlier import AntiOutlierConfig, AntiOutlier
@@ -25,8 +23,8 @@
 
 def is_mindie_turbo_supported() -> bool:
     try:
-        import mindie_turbo
-    except:
+        import mindie_turbo  # noqa: F401
+    except Exception:
         return False
 
     return True

diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py
index 144cfaef..a8be81d0 100644
--- a/vllm_ascend/quantization/quant_config.py
+++ b/vllm_ascend/quantization/quant_config.py
@@ -19,7 +19,7 @@
 from typing import Any, Dict, List, Optional
 
 import torch
-import torch_npu
+import torch_npu  # noqa: F401
 
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,

From dbc7ca255d69844b17078bc65dd90b1118bc2694 Mon Sep 17 00:00:00 2001
From: angazenn
Date: Fri, 21 Feb 2025 09:59:08 +0800
Subject: [PATCH 26/26] add int8 cache dtype when using attention quantization

Signed-off-by: angazenn
---
 vllm_ascend/worker.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/vllm_ascend/worker.py b/vllm_ascend/worker.py
index c5884e36..2323e439 100644
--- a/vllm_ascend/worker.py
+++ b/vllm_ascend/worker.py
@@ -101,6 +101,17 @@
             not in ["medusa", "mlp_speculator", "eagle"]) \
             else {"return_hidden_states": True}
 
+        if vllm_config.quant_config is not None and \
+            'fa_quant_type' in vllm_config.quant_config.quant_description.keys():
+            # Using Ascend attention quantization.
+            # TODO: Updates of cache_config should be added into
+            # NPUPlatform.check_and_update_config. However, that function fails to
+            # update STR_DTYPE_TO_TORCH_DTYPE, which vLLM 0.7.1 uses to convert
+            # a dtype string to a torch.dtype. Hence we have to move this code here.
+            from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
+            cache_config.cache_dtype = 'int8'
+            STR_DTYPE_TO_TORCH_DTYPE['int8'] = torch.int8
+
         ModelRunnerClass: Type[ModelRunnerBase] = NPUModelRunner
         if model_config.runner_type == "pooling":
             ModelRunnerClass = PoolingModelRunner
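
Usage after this series, for reference. The checkpoint path below is a
placeholder for weights produced by msmodelslim (e.g. via the
example_quantization helper in the tests); on an NPU the "ascend" method is
also picked up automatically through override_quantization_method, so the
explicit quantization argument is shown only for clarity:

    from vllm import LLM, SamplingParams

    # Load W8A8 weights that were quantized and saved to disk beforehand.
    llm = LLM(model="path/to/ascend_w8a8_weights",
              quantization="ascend",
              max_model_len=8192,
              dtype="bfloat16")
    outputs = llm.generate(["What's deep learning?"],
                           SamplingParams(temperature=0.0, max_tokens=5))
    print(outputs[0].outputs[0].text)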