diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py
index 5b72ffbf..bb5cb27b 100644
--- a/vllm_ascend/patch/__init__.py
+++ b/vllm_ascend/patch/__init__.py
@@ -15,5 +15,4 @@
 # limitations under the License.
 #
 
-from vllm_ascend.patch import patch_cache_dtype  # noqa
-from vllm_ascend.patch import patch_attention  # noqa
\ No newline at end of file
+from vllm_ascend.patch import patch_attention, patch_cache_dtype  # noqa
\ No newline at end of file
diff --git a/vllm_ascend/patch/patch_attention.py b/vllm_ascend/patch/patch_attention.py
index 8a0e8da7..e43b547c 100644
--- a/vllm_ascend/patch/patch_attention.py
+++ b/vllm_ascend/patch/patch_attention.py
@@ -18,7 +18,7 @@
 #
 # This file is used to monkey patch vLLM Attention.__init__ function
 # and move the instantiation of num_heads, head_size, num_kv_heads
-# ahead of the initialization of attention quant methods, which is 
+# ahead of the initialization of attention quant methods, which is
 # required by ascend attention quant method to initialize.
 # Remove this file when vllm support it.
 
@@ -34,128 +34,127 @@
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
 
 
 def attention_init(
-        self,
-        num_heads: int,
-        head_size: int,
-        scale: float,
-        num_kv_heads: Optional[int] = None,
-        alibi_slopes: Optional[List[float]] = None,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        blocksparse_params: Optional[Dict[str, Any]] = None,
-        logits_soft_cap: Optional[float] = None,
-        per_layer_sliding_window: Optional[int] = None,
-        use_mla: bool = False,
-        prefix: str = "",
-        attn_type: str = AttentionType.DECODER,
-        **extra_impl_args,
-    ) -> None:
-        super(Attention, self).__init__()
-        if per_layer_sliding_window is not None:
-            # per-layer sliding window
-            sliding_window = per_layer_sliding_window
-        elif cache_config is not None:
-            # model-level sliding window
-            sliding_window = cache_config.sliding_window
-        else:
-            sliding_window = None
-
-        if cache_config is not None:
-            kv_cache_dtype = cache_config.cache_dtype
-            block_size = cache_config.block_size
-            is_attention_free = cache_config.is_attention_free
-            calculate_kv_scales = cache_config.calculate_kv_scales
-        else:
-            kv_cache_dtype = "auto"
-            block_size = 16
-            is_attention_free = False
-            calculate_kv_scales = False
-        if num_kv_heads is None:
-            num_kv_heads = num_heads
-
-        # The default k/v_scale is set to 1.0. This is ignored
-        # when kv-cache is not fp8, and should be used with
-        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
-        # expect the pre-quantized k/v_scale to be loaded along
-        # with the model weights.
-        self.kv_cache_dtype = kv_cache_dtype
-        self.calculate_kv_scales = calculate_kv_scales
-        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
-
-        # We also keep the float32 versions of k/v_scale for attention
-        # backends that don't support tensors (Flashinfer)
-        self._k_scale_float = 1.0
-        self._v_scale_float = 1.0
-
-        # should move following three lines before quant method is instantiated.
-        self.num_heads = num_heads
-        self.head_size = head_size
-        self.num_kv_heads = num_kv_heads
-
-        quant_method = quant_config.get_quant_method(
-            self, prefix=prefix) if quant_config else None
-        if quant_method is not None:
-            assert isinstance(quant_method, BaseKVCacheMethod)
-            # TODO (mgoin): kv cache dtype should be specified in the FP8
-            # checkpoint config and become the "auto" behavior
-            if self.kv_cache_dtype == "fp8_e5m2":
-                raise ValueError("fp8_e5m2 kv-cache is not supported with "
-                                 "fp8 checkpoints.")
-            # If quantization is enabled, we make "k_scale" and "v_scale"
-            # parameters so that it can be loaded from the model checkpoint.
-            # The k/v_scale will then be converted back to native float32
-            # values after weight loading.
-            self.quant_method = quant_method
-            self.quant_method.create_weights(self)
-
-        # During model initialization, the default dtype is set as the model
-        # weight and activation dtype.
-        dtype = torch.get_default_dtype()
-        attn_backend = get_attn_backend(head_size,
-                                        dtype,
-                                        kv_cache_dtype,
-                                        block_size,
-                                        is_attention_free,
-                                        blocksparse_params is not None,
-                                        use_mla=use_mla)
-        impl_cls = attn_backend.get_impl_cls()
-        self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
-                             alibi_slopes, sliding_window, kv_cache_dtype,
-                             blocksparse_params, logits_soft_cap, attn_type,
-                             **extra_impl_args)
-        self.sliding_window = sliding_window
-        self.backend = backend_name_to_enum(attn_backend.get_name())
-        self.dtype = dtype
-
-        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
-        # torch.compile works by registering the attention as one giant
-        # opaque custom op. For other platforms, we directly call them
-        # and let torch.compile handle them.
-        self.use_direct_call = not current_platform.is_cuda_alike(
-        ) and not current_platform.is_cpu()
-
-        self.use_output = attn_backend.accept_output_buffer
-        compilation_config = get_current_vllm_config().compilation_config
-        if prefix in compilation_config.static_forward_context:
-            raise ValueError(f"Duplicate layer name: {prefix}")
-        compilation_config.static_forward_context[prefix] = self
-        self.layer_name = prefix
-        self.attn_type = attn_type
-        # use a placeholder kv cache tensor during init, which will be replaced
-        # by bind_kv_cache
-        # this variable will not be accessed if use_direct_call is True
-        self.kv_cache = [
-            torch.tensor([]) for _ in range(get_current_vllm_config(
-            ).parallel_config.pipeline_parallel_size)
-        ]
-
-        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
-        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
+    self,
+    num_heads: int,
+    head_size: int,
+    scale: float,
+    num_kv_heads: Optional[int] = None,
+    alibi_slopes: Optional[List[float]] = None,
+    cache_config: Optional[CacheConfig] = None,
+    quant_config: Optional[QuantizationConfig] = None,
+    blocksparse_params: Optional[Dict[str, Any]] = None,
+    logits_soft_cap: Optional[float] = None,
+    per_layer_sliding_window: Optional[int] = None,
+    use_mla: bool = False,
+    prefix: str = "",
+    attn_type: str = AttentionType.DECODER,
+    **extra_impl_args,
+) -> None:
+    super(Attention, self).__init__()
+    if per_layer_sliding_window is not None:
+        # per-layer sliding window
+        sliding_window = per_layer_sliding_window
+    elif cache_config is not None:
+        # model-level sliding window
+        sliding_window = cache_config.sliding_window
+    else:
+        sliding_window = None
+
+    if cache_config is not None:
+        kv_cache_dtype = cache_config.cache_dtype
+        block_size = cache_config.block_size
+        is_attention_free = cache_config.is_attention_free
+        calculate_kv_scales = cache_config.calculate_kv_scales
+    else:
+        kv_cache_dtype = "auto"
+        block_size = 16
+        is_attention_free = False
+        calculate_kv_scales = False
+    if num_kv_heads is None:
+        num_kv_heads = num_heads
+
+    # The default k/v_scale is set to 1.0. This is ignored
+    # when kv-cache is not fp8, and should be used with
+    # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
+    # expect the pre-quantized k/v_scale to be loaded along
+    # with the model weights.
+    self.kv_cache_dtype = kv_cache_dtype
+    self.calculate_kv_scales = calculate_kv_scales
+    self._k_scale = torch.tensor(1.0, dtype=torch.float32)
+    self._v_scale = torch.tensor(1.0, dtype=torch.float32)
+
+    # We also keep the float32 versions of k/v_scale for attention
+    # backends that don't support tensors (Flashinfer)
+    self._k_scale_float = 1.0
+    self._v_scale_float = 1.0
+
+    # should move following three lines before quant method is instantiated.
+    self.num_heads = num_heads
+    self.head_size = head_size
+    self.num_kv_heads = num_kv_heads
+
+    quant_method = quant_config.get_quant_method(
+        self, prefix=prefix) if quant_config else None
+    if quant_method is not None:
+        assert isinstance(quant_method, BaseKVCacheMethod)
+        # TODO (mgoin): kv cache dtype should be specified in the FP8
+        # checkpoint config and become the "auto" behavior
+        if self.kv_cache_dtype == "fp8_e5m2":
+            raise ValueError("fp8_e5m2 kv-cache is not supported with "
+                             "fp8 checkpoints.")
+        # If quantization is enabled, we make "k_scale" and "v_scale"
+        # parameters so that it can be loaded from the model checkpoint.
+        # The k/v_scale will then be converted back to native float32
+        # values after weight loading.
+        self.quant_method = quant_method
+        self.quant_method.create_weights(self)
+
+    # During model initialization, the default dtype is set as the model
+    # weight and activation dtype.
+    dtype = torch.get_default_dtype()
+    attn_backend = get_attn_backend(head_size,
+                                    dtype,
+                                    kv_cache_dtype,
+                                    block_size,
+                                    is_attention_free,
+                                    blocksparse_params is not None,
+                                    use_mla=use_mla)
+    impl_cls = attn_backend.get_impl_cls()
+    self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
+                         alibi_slopes, sliding_window, kv_cache_dtype,
+                         blocksparse_params, logits_soft_cap, attn_type,
+                         **extra_impl_args)
+    self.sliding_window = sliding_window
+    self.backend = backend_name_to_enum(attn_backend.get_name())
+    self.dtype = dtype
+
+    # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
+    # torch.compile works by registering the attention as one giant
+    # opaque custom op. For other platforms, we directly call them
+    # and let torch.compile handle them.
+    self.use_direct_call = not current_platform.is_cuda_alike(
+    ) and not current_platform.is_cpu()
+
+    self.use_output = attn_backend.accept_output_buffer
+    compilation_config = get_current_vllm_config().compilation_config
+    if prefix in compilation_config.static_forward_context:
+        raise ValueError(f"Duplicate layer name: {prefix}")
+    compilation_config.static_forward_context[prefix] = self
+    self.layer_name = prefix
+    self.attn_type = attn_type
+    # use a placeholder kv cache tensor during init, which will be replaced
+    # by bind_kv_cache
+    # this variable will not be accessed if use_direct_call is True
+    self.kv_cache = [
+        torch.tensor([]) for _ in range(
+            get_current_vllm_config().parallel_config.pipeline_parallel_size)
+    ]
+
+    self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
+    self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
 
 
 Attention.__init__ = attention_init
\ No newline at end of file
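For context on how this patch takes effect: importing `vllm_ascend.patch.patch_attention` (which `vllm_ascend/patch/__init__.py` does) rebinds `Attention.__init__` once at import time, so every `Attention` layer constructed afterwards sets `self.num_heads`, `self.head_size`, and `self.num_kv_heads` before `quant_method.create_weights(self)` runs. The sketch below shows that reordering in isolation; `FakeAttention` and `FakeKVCacheMethod` are invented stand-ins for illustration, not vLLM classes.

```python
# Minimal sketch of the monkey-patch pattern used by patch_attention.py.
# FakeAttention / FakeKVCacheMethod are illustrative stand-ins, not vLLM APIs.


class FakeKVCacheMethod:

    def create_weights(self, layer):
        # An Ascend-style kv-cache quant method reads these attributes,
        # so they must already be set on the layer when this is called.
        print(layer.num_heads, layer.head_size, layer.num_kv_heads)


class FakeAttention:

    def __init__(self, num_heads: int, head_size: int, num_kv_heads: int):
        # Original order: quant method first, head attributes second,
        # which would raise AttributeError inside create_weights().
        FakeKVCacheMethod().create_weights(self)
        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads


def patched_init(self, num_heads: int, head_size: int, num_kv_heads: int):
    # Patched order: set the attributes first, then build the quant method.
    self.num_heads = num_heads
    self.head_size = head_size
    self.num_kv_heads = num_kv_heads
    FakeKVCacheMethod().create_weights(self)


# Rebinding happens once, at import time of the patch module; every
# FakeAttention constructed afterwards runs the reordered initializer.
FakeAttention.__init__ = patched_init

layer = FakeAttention(num_heads=32, head_size=128, num_kv_heads=8)  # prints 32 128 8
```

Because only `__init__` is rebound, the rebinding must happen before any model (and therefore any `Attention` layer) is instantiated; layers created before the import would still run the upstream initializer.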