
Commit 9ed176e
code format
Signed-off-by: MengqingCao <cmq0113@163.com>
MengqingCao committed Feb 22, 2025
1 parent e56a23d commit 9ed176e
Showing 2 changed files with 18 additions and 13 deletions.
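Note: both files change only whitespace, line wrapping, and import grouping, the kind of edits yapf and isort produce. A minimal sketch of reproducing one of the fixes below (assuming yapf is installed; the exact style config this repo uses is not shown in the commit):

# Illustrative only; requires `pip install yapf`.
from yapf.yapflib.yapf_api import FormatCode

before = "q_nope_t = torch.transpose(q_nope, 0,1)\n"
after, changed = FormatCode(before, style_config='pep8')
print(after)    # q_nope_t = torch.transpose(q_nope, 0, 1)
print(changed)  # True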
18 changes: 10 additions & 8 deletions vllm_ascend/attention.py
@@ -742,18 +742,20 @@ def forward(
                                                self.qk_head_dim)
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
                                dim=-1)
 
         k_pe = k_pe.view(num_tokens, self.num_kv_heads, -1)
 
         if self.rotary_emb.__class__.__name__ == 'RotaryEmbedding':
             ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
             q_pe = q_pe.reshape(num_tokens, -1)
             k_pe = k_pe.reshape(num_tokens, -1)
-            q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe, k_pe)
+            q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe,
+                                         k_pe)
             q_pe = q_pe.view(ori_q_pe_shape)
             k_pe = k_pe.view(ori_k_pe_shape)
         else:
-            q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe, k_pe)
+            q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe,
+                                         k_pe)
 
         if self.w_kc is None or self.w_vc is None:
             kv_b_proj_weight = self.kv_b_proj.weight.reshape(
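Aside (not part of the commit): the reshape round-trip above exists because the NPU rope path operates on 2-D (num_tokens, heads * head_dim) views, so the caller must flatten and then restore the original 3-D shapes. A minimal sketch of that pattern, with a hypothetical rope_2d callable standing in for the NPU kernel:

import torch

def apply_rope_flat(q_pe: torch.Tensor, k_pe: torch.Tensor, rope_2d):
    # Flatten (tokens, heads, dim) to 2-D, apply the kernel, restore shapes.
    ori_q_shape, ori_k_shape = q_pe.shape, k_pe.shape
    num_tokens = q_pe.shape[0]
    q_flat, k_flat = rope_2d(q_pe.reshape(num_tokens, -1),
                             k_pe.reshape(num_tokens, -1))
    return q_flat.view(ori_q_shape), k_flat.view(ori_k_shape)

q, k = torch.randn(4, 8, 64), torch.randn(4, 1, 64)
q2, k2 = apply_rope_flat(q, k, lambda q, k: (q, k))  # identity rope for demo
assert q2.shape == q.shape and k2.shape == k.shape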
@@ -779,9 +781,9 @@ def forward(
                 dim=2)
         else:
             kv_heads_num = self.num_kv_heads
-            q_nope_t = torch.transpose(q_nope, 0,1)
+            q_nope_t = torch.transpose(q_nope, 0, 1)
             q_nope_out = torch.bmm(q_nope_t, self.w_kc)
-            q_nope = torch.transpose(q_nope_out,0,1)
+            q_nope = torch.transpose(q_nope_out, 0, 1)
         k_cache = torch.cat(
             [kv_c_normed.view(num_tokens, self.num_kv_heads, -1), k_pe],
             dim=2)
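Aside: the transpose/bmm/transpose triple reformatted above (and the matching attn_output / self.w_vc triple in the next hunk) applies a per-head weight; torch.bmm batches over the leading dimension, so heads must be moved to the front. A shape-level sketch with made-up sizes:

import torch

num_tokens, num_heads, nope_dim, kv_lora_rank = 4, 8, 128, 512
q_nope = torch.randn(num_tokens, num_heads, nope_dim)
w_kc = torch.randn(num_heads, nope_dim, kv_lora_rank)  # per-head weight

q_nope_t = torch.transpose(q_nope, 0, 1)    # (heads, tokens, nope_dim)
q_nope_out = torch.bmm(q_nope_t, w_kc)      # (heads, tokens, kv_lora_rank)
q_nope = torch.transpose(q_nope_out, 0, 1)  # (tokens, heads, kv_lora_rank)
assert q_nope.shape == (num_tokens, num_heads, kv_lora_rank)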
@@ -881,10 +883,10 @@ def forward(
                 inputLayout=0,
                 outDataType=-1,
                 attnOut=attn_output)
-            attn_output_t = torch.transpose(attn_output, 0,1)
+            attn_output_t = torch.transpose(attn_output, 0, 1)
             attn_output_t = torch.bmm(attn_output_t, self.w_vc)
-            attn_output = torch.transpose(attn_output_t, 0,1)
+            attn_output = torch.transpose(attn_output_t, 0, 1)
 
         output, _ = self.o_proj(attn_output.reshape(num_tokens, -1))
 
-        return output
\ No newline at end of file
+        return output
13 changes: 8 additions & 5 deletions vllm_ascend/ops/rotary_embedding.py
@@ -18,7 +18,8 @@
 from typing import Optional, Tuple
 
 import torch
-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding , DeepseekScalingRotaryEmbedding
+from vllm.model_executor.layers.rotary_embedding import (
+    DeepseekScalingRotaryEmbedding, RotaryEmbedding)
 
 
 def rope_forward_oot(
@@ -51,6 +52,7 @@ def rope_forward_oot(
     )
     return query, key
 
+
 def rope_deepseek_forward_oot(
     self,
     positions: torch.Tensor,
@@ -69,9 +71,9 @@ def rope_deepseek_forward_oot(
             "Batched rotary embedding is currently not supported on NPU.")
     else:
         # TODO: Remove the contiguous in the future.
-        ori_query_shape,ori_key_shape = query.shape,key.shape
-        query = query.contiguous().view(query.shape[0],-1)
-        key = key.contiguous().view(query.shape[0],-1)
+        ori_query_shape, ori_key_shape = query.shape, key.shape
+        query = query.contiguous().view(query.shape[0], -1)
+        key = key.contiguous().view(query.shape[0], -1)
         torch_npu.npu_rope(
             positions,
             query,
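Aside: the contiguous().view(...) calls reformatted above hand torch_npu.npu_rope dense 2-D buffers it can update in place (the TODO suggests the extra copy is meant to go away eventually). A sketch of the same contract with a hypothetical in-place 2-D op, not the real torch_npu API:

import torch

def rope_inplace_2d(positions, query_2d, key_2d):
    # Hypothetical stand-in for the NPU kernel: mutates its 2-D inputs.
    query_2d.add_(0.0)
    key_2d.add_(0.0)

def apply_rope(positions, query, key):
    ori_query_shape, ori_key_shape = query.shape, key.shape
    query = query.contiguous().view(query.shape[0], -1)
    key = key.contiguous().view(key.shape[0], -1)
    rope_inplace_2d(positions, query, key)
    return query.view(ori_query_shape), key.view(ori_key_shape)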
@@ -85,5 +87,6 @@ def rope_deepseek_forward_oot(
 
     return query, key
 
+
 RotaryEmbedding.forward_oot = rope_forward_oot
-DeepseekScalingRotaryEmbedding.forward = rope_deepseek_forward_oot
\ No newline at end of file
+DeepseekScalingRotaryEmbedding.forward = rope_deepseek_forward_oot
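Aside: the two assignments above are the out-of-tree override mechanism: the vLLM classes are patched at import time, so every instance, existing or future, dispatches to the NPU implementation. The same pattern in miniature:

class SomeLayer:  # hypothetical stand-in for a vLLM layer class
    def forward(self, x):
        return x  # default (in-tree) implementation

def forward_npu(self, x):
    # an out-of-tree version would call torch_npu kernels here
    return x

SomeLayer.forward = forward_npu  # reroute all instances

assert SomeLayer().forward(1) == 1  # now runs forward_npu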
