diff --git a/examples/offline_distributed_inference_npu.py b/examples/offline_distributed_inference_npu.py
index f8d5489a..8e503ad2 100644
--- a/examples/offline_distributed_inference_npu.py
+++ b/examples/offline_distributed_inference_npu.py
@@ -29,11 +29,10 @@
 # Create a sampling params object.
 sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
 # Create an LLM.
-# TODO (cmq): ray is not supported currently, need some fixes
 llm = LLM(
     model="facebook/opt-125m",
     tensor_parallel_size=2,
-    distributed_executor_backend="mp",
+    distributed_executor_backend="ray",
     trust_remote_code=True,
 )
 
diff --git a/vllm_ascend/attention.py b/vllm_ascend/attention.py
index 0a014e68..2f9b5e70 100644
--- a/vllm_ascend/attention.py
+++ b/vllm_ascend/attention.py
@@ -457,9 +457,7 @@ def __init__(
         self.kv_cache_dtype = kv_cache_dtype
         self.sliding_window = sliding_window
         if alibi_slopes is not None:
-            alibi_slopes = torch.tensor(alibi_slopes,
-                                        dtype=torch.float32,
-                                        device="npu")
+            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
         self.alibi_slopes = alibi_slopes
         self.attn_type = attn_type
 
@@ -520,7 +518,7 @@ def forward(
                 attn_metadata.sparse_mode = 2
                 attention_mask = gen_input_mask(
                     attn_metadata.max_prefill_seq_len, self.sliding_window,
-                    num_tokens)
+                    num_tokens, query.device)
                 attn_metadata.attn_mask = attention_mask
 
             if (self.alibi_slopes is not None
@@ -531,6 +529,7 @@ def forward(
                     dtype=query.dtype,
                     seq_len=attn_metadata.max_prefill_seq_len,
                     batch_size=num_tokens,
+                    device=query.device,
                 )
 
             if (len(kv_cache) == 0 or attn_metadata.block_tables is None
@@ -571,7 +570,7 @@ def forward(
             query = query.view(query.shape[0], -1,
                                self.num_heads * self.head_size)
             output = torch.zeros(query.shape,
-                                 device="npu",
+                                 device=query.device,
                                  dtype=query.dtype)
             # TODO (Mengqing Cao): torch_npu.npu_incre_flash_attention
             # support only when `S == 1`, OPTIMIZE ME when prefix caching
@@ -621,7 +620,7 @@ def forward(
         return output
 
 
-def gen_input_mask(seq_len, sliding_window, len):
+def gen_input_mask(seq_len, sliding_window, len, device):
     """
     Generating lower triangular matrix
     """
@@ -630,7 +629,7 @@ def gen_input_mask(seq_len, sliding_window, len):
         global SHARE_MASK_TRIL_PREFIX_CACHE
         if SHARE_MASK_TRIL_PREFIX_CACHE is None:
             SHARE_MASK_TRIL_PREFIX_CACHE = torch.triu(
-                torch.ones(1, 1, 2048, 2048, dtype=bool, device="npu"),
+                torch.ones(1, 1, 2048, 2048, dtype=bool, device=device),
                 diagonal=1,
             )
         attention_mask = SHARE_MASK_TRIL_PREFIX_CACHE
@@ -638,7 +637,7 @@ def gen_input_mask(seq_len, sliding_window, len):
         global SHARE_MASK_TRIL
         if SHARE_MASK_TRIL is None or SHARE_MASK_TRIL.shape[0] < seq_len:
             SHARE_MASK_TRIL = ~torch.tril(
-                torch.ones(seq_len, seq_len, dtype=bool, device="npu"))
+                torch.ones(seq_len, seq_len, dtype=bool, device=device))
 
         attention_mask = SHARE_MASK_TRIL
         if sliding_window is not None:
@@ -656,8 +655,10 @@ def _make_alibi_bias(
     dtype: torch.dtype,
     seq_len: int,
     batch_size: int,
+    device: torch.device,
 ):
-    bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device)
+    alibi_slopes = alibi_slopes.to(device)
+    bias = torch.arange(seq_len, dtype=dtype, device=device)
     # NOTE(zhuohan): HF uses
     #     `bias = bias[None, :].repeat(seq_len, 1)`
     # here. We find that both biases give the same results, but
@@ -674,7 +675,7 @@ def _make_alibi_bias(
         num_heads,
         seq_len,
         padded_len,
-        device=alibi_slopes.device,
+        device=device,
         dtype=dtype,
     )[:, :, :, :seq_len].copy_(bias)
     bias.mul_(alibi_slopes[:, None, None])
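
The common thread in the attention.py hunks is that device placement is no longer hard-coded to the "npu" string: constant tensors are created device-agnostically, and the mask/bias helpers receive an explicit device taken from the incoming query. The sketch below is not part of the patch; it is a minimal, hypothetical illustration of that pattern (TinyAttention and make_causal_mask are invented names). It runs on CPU as written; on an Ascend host the query would simply be an NPU tensor and everything would follow its device.

# Hedged illustration (not from the patch) of the device-threading pattern
# adopted above. TinyAttention and make_causal_mask are hypothetical names.
import torch


def make_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
    # Causal mask built directly on the caller's device, analogous to
    # gen_input_mask(..., device) after the change.
    return ~torch.tril(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))


class TinyAttention:

    def __init__(self, alibi_slopes=None):
        # As in the patched __init__: keep the tensor device-agnostic (CPU)
        # instead of forcing device="npu" at construction time.
        self.alibi_slopes = (torch.tensor(alibi_slopes, dtype=torch.float32)
                             if alibi_slopes is not None else None)

    def forward(self, query: torch.Tensor) -> torch.Tensor:
        # query: (num_heads, seq_len, head_size); everything below follows
        # query.device, so the same code runs on CPU or an NPU.
        seq_len = query.shape[-2]
        mask = make_causal_mask(seq_len, query.device)
        scores = query @ query.transpose(-1, -2) / query.shape[-1]**0.5
        if self.alibi_slopes is not None:
            # Move the slopes to wherever the query lives, mirroring
            # alibi_slopes.to(device) in the patched _make_alibi_bias.
            slopes = self.alibi_slopes.to(query.device)
            positions = torch.arange(seq_len, dtype=query.dtype,
                                     device=query.device)
            scores = scores + slopes[:, None, None] * positions[None, None, :]
        return scores.masked_fill(mask, float("-inf")).softmax(dim=-1)


# Runs on CPU as written; with torch_npu installed the query could be an
# NPU tensor and no code change would be needed.
probs = TinyAttention(alibi_slopes=[0.5, 0.25]).forward(torch.randn(2, 8, 16))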