[attn] fix device of tensors in attention (#25)
### What this PR does / why we need it?
Fix the device of tensors created in `AscendAttentionBackendImpl`.

When the device is set to any card other than card-0, a **device conflict** occurs because tensors such as `attn_mask` are placed on card-0 by default.

This PR creates these tensors on the card corresponding to the input.
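
A minimal sketch of the pattern this change applies (illustrative only; it assumes an Ascend environment with the `torch_npu` plugin, and the tensor names are hypothetical rather than taken from the code):

```python
import torch
import torch_npu  # noqa: F401  # assumed: Ascend plugin that registers the "npu" device

# Hypothetical input the caller has already placed on card-1.
query = torch.randn(4, 8, 128, device="npu:1")

# Before: a hardcoded device string puts the mask on card-0 by default,
# which conflicts with `query` living on card-1.
# attn_mask = torch.ones(4, 4, dtype=torch.bool, device="npu")

# After: derive the device from the input tensor, so the mask always
# follows the data, whichever card it is on.
attn_mask = torch.ones(4, 4, dtype=torch.bool, device=query.device)
```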

### Does this PR introduce _any_ user-facing change?
With this PR, users can specify the device by local rank. A corresponding change in vLLM is also required; that PR will be linked here once it is created.

### How was this patch tested?
This was tested locally with the code below. A test case will be added once the corresponding change in vLLM is completed.
```python
from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Create a sampling params object.
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
# Create an LLM.
llm = LLM(model="~/.cache/modelscope/hub/Qwen/Qwen2___5-7B-Instruct", device="npu:1")

# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

Signed-off-by: MengqingCao <cmq0113@163.com>
MengqingCao authored Feb 10, 2025
1 parent c59375c commit 7006835
Showing 2 changed files with 12 additions and 12 deletions.
`examples/offline_distributed_inference_npu.py` (3 changes: 1 addition & 2 deletions)

```diff
@@ -29,11 +29,10 @@
 # Create a sampling params object.
 sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
 # Create an LLM.
-# TODO (cmq): ray is not supported currently, need some fixes
 llm = LLM(
     model="facebook/opt-125m",
     tensor_parallel_size=2,
-    distributed_executor_backend="mp",
+    distributed_executor_backend="ray",
     trust_remote_code=True,
 )
 
```
`vllm_ascend/attention.py` (21 changes: 11 additions & 10 deletions)

```diff
@@ -457,9 +457,7 @@ def __init__(
         self.kv_cache_dtype = kv_cache_dtype
         self.sliding_window = sliding_window
         if alibi_slopes is not None:
-            alibi_slopes = torch.tensor(alibi_slopes,
-                                        dtype=torch.float32,
-                                        device="npu")
+            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
         self.alibi_slopes = alibi_slopes
         self.attn_type = attn_type
 
```
```diff
@@ -520,7 +518,7 @@ def forward(
                 attn_metadata.sparse_mode = 2
                 attention_mask = gen_input_mask(
                     attn_metadata.max_prefill_seq_len, self.sliding_window,
-                    num_tokens)
+                    num_tokens, query.device)
                 attn_metadata.attn_mask = attention_mask
 
             if (self.alibi_slopes is not None
```
```diff
@@ -531,6 +529,7 @@
                     dtype=query.dtype,
                     seq_len=attn_metadata.max_prefill_seq_len,
                     batch_size=num_tokens,
+                    device=query.device,
                 )
 
             if (len(kv_cache) == 0 or attn_metadata.block_tables is None
```
```diff
@@ -571,7 +570,7 @@ def forward(
                 query = query.view(query.shape[0], -1,
                                    self.num_heads * self.head_size)
                 output = torch.zeros(query.shape,
-                                     device="npu",
+                                     device=query.device,
                                      dtype=query.dtype)
                 # TODO (Mengqing Cao): torch_npu.npu_incre_flash_attention
                 # support only when `S == 1`, OPTIMIZE ME when prefix caching
```
```diff
@@ -621,7 +620,7 @@ def forward(
         return output
 
 
-def gen_input_mask(seq_len, sliding_window, len):
+def gen_input_mask(seq_len, sliding_window, len, device):
     """
     Generating lower triangular matrix
     """
```
```diff
@@ -630,15 +629,15 @@ def gen_input_mask(seq_len, sliding_window, len):
         global SHARE_MASK_TRIL_PREFIX_CACHE
         if SHARE_MASK_TRIL_PREFIX_CACHE is None:
             SHARE_MASK_TRIL_PREFIX_CACHE = torch.triu(
-                torch.ones(1, 1, 2048, 2048, dtype=bool, device="npu"),
+                torch.ones(1, 1, 2048, 2048, dtype=bool, device=device),
                 diagonal=1,
             )
         attention_mask = SHARE_MASK_TRIL_PREFIX_CACHE
     else:
         global SHARE_MASK_TRIL
         if SHARE_MASK_TRIL is None or SHARE_MASK_TRIL.shape[0] < seq_len:
             SHARE_MASK_TRIL = ~torch.tril(
-                torch.ones(seq_len, seq_len, dtype=bool, device="npu"))
+                torch.ones(seq_len, seq_len, dtype=bool, device=device))
 
         attention_mask = SHARE_MASK_TRIL
         if sliding_window is not None:
```
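
For orientation, a simplified standalone sketch of the lower-triangular (causal) mask that `gen_input_mask` caches, built on a caller-supplied device instead of a hardcoded one (an illustration of the technique, not the repository code):

```python
import torch

def causal_mask_sketch(seq_len: int, device: torch.device) -> torch.Tensor:
    # True marks positions to mask out (future tokens), i.e. the strict
    # upper triangle of a (seq_len, seq_len) matrix.
    return torch.triu(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=device),
        diagonal=1,
    )

# Example: a 4x4 causal mask on CPU (or, on Ascend, e.g. torch.device("npu:1")).
print(causal_mask_sketch(4, torch.device("cpu")))
```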
```diff
@@ -656,8 +655,10 @@ def _make_alibi_bias(
     dtype: torch.dtype,
     seq_len: int,
     batch_size: int,
+    device: torch.device,
 ):
-    bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device)
+    alibi_slopes = alibi_slopes.to(device)
+    bias = torch.arange(seq_len, dtype=dtype, device=device)
     # NOTE(zhuohan): HF uses
     # `bias = bias[None, :].repeat(seq_len, 1)`
     # here. We find that both biases give the same results, but
```
```diff
@@ -674,7 +675,7 @@
         num_heads,
         seq_len,
         padded_len,
-        device=alibi_slopes.device,
+        device=device,
         dtype=dtype,
     )[:, :, :, :seq_len].copy_(bias)
     bias.mul_(alibi_slopes[:, None, None])
```
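
Likewise, a minimal standalone sketch of an ALiBi-style bias built entirely on the caller's device, mirroring the `device` threading that `_make_alibi_bias` gains here (a simplified illustration under assumed shapes, not the repository's implementation):

```python
import torch

def alibi_bias_sketch(alibi_slopes: torch.Tensor,
                      seq_len: int,
                      device: torch.device,
                      dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # Move the per-head slopes to the target device first, so every
    # intermediate tensor below lives on the same card.
    alibi_slopes = alibi_slopes.to(device)
    pos = torch.arange(seq_len, dtype=dtype, device=device)
    # Relative distances (j - i), shape (seq_len, seq_len).
    rel = pos[None, :] - pos[:, None]
    # Scale by one slope per head -> (num_heads, seq_len, seq_len).
    return rel[None, :, :] * alibi_slopes[:, None, None]

# Example: 8 heads, 16-token sequence, on CPU (or "npu:1" on Ascend).
slopes = torch.rand(8)
print(alibi_bias_sketch(slopes, 16, torch.device("cpu")).shape)  # (8, 16, 16)
```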
