diff --git a/vllm_ascend/attention.py b/vllm_ascend/attention.py
index 433e83f..674e100 100644
--- a/vllm_ascend/attention.py
+++ b/vllm_ascend/attention.py
@@ -770,13 +770,9 @@ def forward(
                 num_blocks, block_size, self.num_kv_heads,
                 self.qk_rope_head_dim + self.kv_lora_rank)
             slots = attn_metadata.slot_mapping
-            torch_npu.npu_reshapecache(key=k_cache,
-                                       value=None,
-                                       keyCache=key_cache,
-                                       valueCache=None,
-                                       slotMapping=slots,
-                                       compressType=0,
-                                       kvCacheCfg=1)
+            torch_npu._npu_reshape_and_cache_siso(key=k_cache,
+                                                  key_cache=key_cache,
+                                                  slot_indices=slots)

         if attn_metadata.num_prefills > 0:
             attn_output = torch.empty(num_tokens,
@@ -793,32 +789,16 @@ def forward(
                 self.seq_lens_tensor_cpu = torch.from_numpy(
                     np.array(attn_metadata.prefill_metadata.seq_lens).astype(
                         np.int32))
-                torch_npu.npu_selfattention(query=query,
-                                            key=key,
-                                            value=value,
-                                            kvcacheCfg=0,
-                                            mask=mask,
-                                            maskType=1,
-                                            isTriuMask=0,
-                                            seqLen=self.seq_lens_tensor_cpu,
-                                            scale=self.scale,
-                                            qScale=1,
-                                            scaleType=0,
-                                            headNum=self.num_heads,
-                                            kvHeadNum=self.num_heads,
-                                            mlaVHeadSize=0,
-                                            calcType=3,
-                                            kernelType=0,
-                                            clampType=0,
-                                            quantType=0,
-                                            cacheType=0,
-                                            windowSize=0,
-                                            clampMin=0,
-                                            clampMax=0,
-                                            batchRunStatusEnable=False,
-                                            inputLayout=0,
-                                            outDataType=0,
-                                            out=attn_output)
+                torch_npu._npu_flash_attention(
+                    query=query,
+                    key=key,
+                    value=value,
+                    mask=mask,
+                    seq_len=self.seq_lens_tensor_cpu,
+                    scale_value=self.scale,
+                    num_heads=self.num_heads,
+                    num_kv_heads=self.num_heads,
+                    out=attn_output)
             else:
                 # TODO: Will support prefix cache and chunked prefill soon.
                 raise RuntimeError(
@@ -835,25 +815,16 @@ def forward(
                 np.array(attn_metadata.decode_metadata.seq_lens).astype(
                     np.int32))
             block_tables = attn_metadata.decode_metadata.block_tables
-            torch_npu.npu_pagedattention(query=query,
-                                         keyCache=key_cache,
-                                         valueCache=None,
-                                         contextLens=self.seq_lens_tensor_cpu,
-                                         maskType=0,
-                                         kvHeadNum=self.num_kv_heads,
-                                         headNum=self.num_heads,
-                                         mlaVHeadSize=self.kv_lora_rank,
-                                         qkScale=self.scale,
-                                         blockTables=block_tables,
-                                         batchRunStatusEnable=False,
-                                         hasQuantOffset=False,
-                                         compressType=0,
-                                         calcType=0,
-                                         scaleType=0,
-                                         quantType=0,
-                                         inputLayout=0,
-                                         outDataType=-1,
-                                         attnOut=attn_output)
+            torch_npu._npu_paged_attention_mla(
+                query=query,
+                key_cache=key_cache,
+                num_kv_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale_value=self.scale,
+                block_table=block_tables,
+                context_lens=self.seq_lens_tensor_cpu,
+                mla_vheadsize=self.kv_lora_rank,
+                out=attn_output)
             attn_output_t = torch.transpose(attn_output, 0, 1)
             attn_output_t = torch.bmm(attn_output_t, self.w_vc)
             attn_output = torch.transpose(attn_output_t, 0, 1)
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index c2d4146..f3e47ab 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -50,10 +50,9 @@ def group_topk(hidden_states: torch.Tensor,
     topk_group = 0 if topk_group is None else topk_group
     num_expert_group = 0 if num_expert_group is None else num_expert_group

-    torch_npu.npu_group_topk(input=scores,
-                             out=scores,
-                             group_num=num_expert_group,
-                             k=topk_group)
+    torch_npu._npu_group_topk(self=scores,
+                              k=topk_group,
+                              group_num=num_expert_group)
     if e_score_correction_bias is not None:
         topk_ids = torch.topk(scores, k=topk, dim=-1, sorted=False)[1]
         # Use original unbiased scores for the routing weights
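
Note for reviewers: the sketch below exercises the new decode-path kwarg surface (_npu_paged_attention_mla followed by the unchanged w_vc projection) outside of vllm-ascend. Only the argument names mirror what this diff switches to; the head counts, ranks, block layout, scale and v_head_dim values are made-up placeholders rather than anything derived from a model config, and it needs a device with torch_npu available to run.

# Hypothetical, self-contained sketch -- only the kwarg names mirror the ops
# in the diff; every shape and constant below is an assumed placeholder.
import numpy as np
import torch
import torch_npu  # provides the _npu_* custom ops and the .npu() device copy

num_tokens, num_heads, num_kv_heads = 4, 16, 1      # assumed head layout
kv_lora_rank, rope_dim, v_head_dim = 512, 64, 128   # assumed MLA dims
num_blocks, block_size = 8, 128                     # assumed cache layout

# Assumed decode-time layouts: the query carries the latent (kv_lora_rank)
# plus RoPE dims, and the cache stores one latent "KV head" per block slot.
query = torch.randn(num_tokens, num_heads, kv_lora_rank + rope_dim,
                    dtype=torch.float16).npu()
key_cache = torch.randn(num_blocks, block_size, num_kv_heads,
                        rope_dim + kv_lora_rank, dtype=torch.float16).npu()
# One decode token per sequence; every sequence maps to block 0 here.
block_tables = torch.zeros(num_tokens, num_blocks, dtype=torch.int32).npu()
context_lens = torch.from_numpy(np.ones(num_tokens, dtype=np.int32))
attn_output = torch.empty(num_tokens, num_heads, kv_lora_rank,
                          dtype=query.dtype, device=query.device)

torch_npu._npu_paged_attention_mla(query=query,
                                   key_cache=key_cache,
                                   num_kv_heads=num_kv_heads,
                                   num_heads=num_heads,
                                   scale_value=(kv_lora_rank + rope_dim)**-0.5,
                                   block_table=block_tables,
                                   context_lens=context_lens,
                                   mla_vheadsize=kv_lora_rank,
                                   out=attn_output)

# Project the latent attention output back to the value head dimension with
# w_vc, mirroring the unchanged transpose/bmm/transpose lines after the call.
w_vc = torch.randn(num_heads, kv_lora_rank, v_head_dim,
                   dtype=torch.float16).npu()
attn_output = torch.bmm(attn_output.transpose(0, 1), w_vc).transpose(0, 1)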