From d0b3cb4fa79d5fc7f8245a3c68885ce1fa030ba4 Mon Sep 17 00:00:00 2001
From: Yaphets24 <44045681+Yaphets24@users.noreply.github.com>
Date: Sat, 22 Feb 2025 17:43:42 +0800
Subject: [PATCH] modify: Eliminate redundant operations in the code to
 improve performance (#137)

### What this PR does / why we need it?
Eliminate redundant operations in the code to improve performance. In particular:
- Reuse the input positions already computed by the model runner instead of
  rebuilding them from `seq_lens` on every MLA attention forward pass.
- Replace copy-making `repeat` / `torch_npu.npu_transpose(..., require_contiguous=True)`
  calls with view-based `expand` / `torch.transpose`.
- Drop redundant dtype casts in the fused MoE routing path.
- Route `DeepseekScalingRotaryEmbedding` through the NPU rope path.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI passed

---------

Signed-off-by: Yaphets24
Signed-off-by: MengqingCao
Co-authored-by: MengqingCao
---
 vllm_ascend/attention.py            | 36 +++++++++------------------
 vllm_ascend/model_runner.py         |  2 ++
 vllm_ascend/ops/fused_moe.py        |  5 ++--
 vllm_ascend/ops/rotary_embedding.py | 38 ++++++++++++++++++++++++++++-
 4 files changed, 52 insertions(+), 29 deletions(-)

diff --git a/vllm_ascend/attention.py b/vllm_ascend/attention.py
index 3088efb5..66bc45e1 100644
--- a/vllm_ascend/attention.py
+++ b/vllm_ascend/attention.py
@@ -742,30 +742,20 @@ def forward(
                                      self.qk_head_dim)
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
                                dim=-1)
-        if attn_metadata.num_prefills > 0:
-            assert attn_metadata.prefill_metadata is not None
-            assert attn_metadata.prefill_metadata.seq_lens is not None
-            np_positions = np.concatenate([
-                np.arange(i) for i in attn_metadata.prefill_metadata.seq_lens
-            ])
-            positions = torch.tensor(np_positions,
-                                     device=hidden_states_or_q_c.device)
-        else:
-            assert attn_metadata.decode_metadata is not None
-            np_positions = np.array(attn_metadata.decode_metadata.seq_lens) - 1
-            positions = torch.tensor(np_positions,
-                                     device=hidden_states_or_q_c.device)
+        k_pe = k_pe.view(num_tokens, self.num_kv_heads, -1)
         if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
             ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
             q_pe = q_pe.reshape(num_tokens, -1)
             k_pe = k_pe.reshape(num_tokens, -1)
-            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+            q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe,
+                                         k_pe)
             q_pe = q_pe.view(ori_q_pe_shape)
             k_pe = k_pe.view(ori_k_pe_shape)
         else:
-            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+            q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe,
+                                         k_pe)
 
         if self.w_kc is None or self.w_vc is None:
             kv_b_proj_weight = self.kv_b_proj.weight.reshape(
@@ -786,16 +776,14 @@ def forward(
             k_cache = torch.cat(
                 [kv_c_normed.view(num_tokens, self.num_kv_heads, -1), k_pe],
                 dim=2)
-            k_pe = k_pe.repeat(1, self.num_heads, 1)
+            k_pe = k_pe.expand(-1, self.num_heads, -1)
             key = torch.cat([k_nope.view(num_tokens, kv_heads_num, -1), k_pe],
                             dim=2)
         else:
             kv_heads_num = self.num_kv_heads
-            q_nope_t = torch_npu.npu_transpose(q_nope, (1, 0, 2),
-                                               require_contiguous=True)
+            q_nope_t = torch.transpose(q_nope, 0, 1)
             q_nope_out = torch.bmm(q_nope_t, self.w_kc)
-            q_nope = torch_npu.npu_transpose(q_nope_out, (1, 0, 2),
-                                             require_contiguous=True)
+            q_nope = torch.transpose(q_nope_out, 0, 1)
             k_cache = torch.cat(
                 [kv_c_normed.view(num_tokens, self.num_kv_heads, -1), k_pe],
                 dim=2)
@@ -895,12 +883,10 @@ def forward(
             inputLayout=0,
             outDataType=-1,
             attnOut=attn_output)
-        attn_output_t = torch_npu.npu_transpose(attn_output, (1, 0, 2),
-                                                require_contiguous=True)
+        attn_output_t = torch.transpose(attn_output, 0, 1)
         attn_output_t = torch.bmm(attn_output_t, self.w_vc)
-        attn_output = torch_npu.npu_transpose(attn_output_t, (1, 0, 2),
-                                              require_contiguous=True)
+        attn_output = torch.transpose(attn_output_t, 0, 1)
 
-        output, _ = self.o_proj(attn_output.view(num_tokens, -1))
+        output, _ = self.o_proj(attn_output.reshape(num_tokens, -1))
         return output
diff --git a/vllm_ascend/model_runner.py b/vllm_ascend/model_runner.py
index 2bb057fb..d0aa06dc 100644
--- a/vllm_ascend/model_runner.py
+++ b/vllm_ascend/model_runner.py
@@ -1137,6 +1137,8 @@ def execute_model(
         if not bypass_model_exec:
             with set_forward_context(model_input.attn_metadata,
                                      self.vllm_config, virtual_engine):
+                if model_input.attn_metadata is not None:
+                    model_input.attn_metadata.input_positions = model_input.input_positions
                 hidden_or_intermediate_states = model_executable(
                     input_ids=model_input.input_tokens,
                     positions=model_input.input_positions,
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index cbb86224..db03509f 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -65,7 +65,7 @@ def group_topk(hidden_states: torch.Tensor,
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
 
-    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+    return topk_weights, topk_ids.to(torch.int32)
 
 
 def fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor,
@@ -126,13 +126,12 @@ def fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor,
     down_out_list = torch.cat(down_out_list, dim=0)
     # TODO: Reorder device memory 2 times here, replace the current
     # implementation here when suitable operators become available.
-    routing_weights = topk_weights.to(down_out_list.dtype)
     hidden_states = torch_npu.npu_moe_finalize_routing(
         down_out_list,
         skip1=None,
         skip2=None,
         bias=None,
-        scales=routing_weights,
+        scales=topk_weights,
         expanded_src_to_dst_row=expanded_row_idx,
         export_for_source_row=topk_ids)
     if len(ori_shape) == 3:
diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py
index 2279ad15..1999386b 100644
--- a/vllm_ascend/ops/rotary_embedding.py
+++ b/vllm_ascend/ops/rotary_embedding.py
@@ -18,7 +18,8 @@
 from typing import Optional, Tuple
 
 import torch
-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+from vllm.model_executor.layers.rotary_embedding import (
+    DeepseekScalingRotaryEmbedding, RotaryEmbedding)
 
 
 def rope_forward_oot(
@@ -49,8 +50,43 @@ def rope_forward_oot(
         self.cos_sin_cache,
         self.is_neox_style,
     )
+    return query, key
+
+
+def rope_deepseek_forward_oot(
+    self,
+    positions: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    offsets: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    import torch_npu
+
+    if self.cos_sin_cache.device != query.device:
+        self.cos_sin_cache = self.cos_sin_cache.to(query.device)
+    if self.cos_sin_cache.dtype != query.dtype:
+        self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
+    if offsets is not None:
+        raise NotImplementedError(
+            "Batched rotary embedding is currently not supported on NPU.")
+    else:
+        # TODO: Remove the contiguous in the future.
+        ori_query_shape, ori_key_shape = query.shape, key.shape
+        query = query.contiguous().view(query.shape[0], -1)
+        key = key.contiguous().view(query.shape[0], -1)
+        torch_npu.npu_rope(
+            positions,
+            query,
+            key,
+            self.head_size,
+            self.cos_sin_cache,
+            self.is_neox_style,
+        )
+        query = query.view(ori_query_shape)
+        key = key.view(ori_key_shape)
     return query, key
 
 
 RotaryEmbedding.forward_oot = rope_forward_oot
+DeepseekScalingRotaryEmbedding.forward = rope_deepseek_forward_oot