refine code
Seventeen17 committed Nov 11, 2024
1 parent 5fdfca5 commit beb8ee4
Showing 10 changed files with 57 additions and 78 deletions.
7 changes: 3 additions & 4 deletions test/test_flash_attention_backward.py
@@ -1,13 +1,12 @@
 import os
 import unittest
+import pytest
 
-from flash_attn import flash_attn_func
 import numpy as np
 import torch
 import torch.nn.functional as F
-import pytest
 import torch_xla
 import torch_xla.core.xla_model as xm
 
+from flash_attn import flash_attn_func
 import flash_attn_2_cuda as flash_attn_cuda
 import torchacc as ta
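
For orientation, these tests compare flash_attn_func (and its gradients) against a plain-PyTorch attention reference. A minimal sketch of that pattern, with illustrative shapes, tolerances, and reference implementation rather than the repository's actual test code:

# Sketch only: validates the fused kernel against an eager-mode reference.
import torch
from flash_attn import flash_attn_func

def attention_reference(q, k, v, scale):
    # q, k, v: (batch, seqlen, nheads, head_dim); compute in (batch, nheads, seqlen, head_dim).
    q, k, v = (t.permute(0, 2, 1, 3).float() for t in (q, k, v))
    probs = torch.softmax(torch.matmul(q, k.transpose(-2, -1)) * scale, dim=-1)
    return torch.matmul(probs, v).permute(0, 2, 1, 3)

q, k, v = (torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda",
                       requires_grad=True) for _ in range(3))
scale = q.shape[-1] ** -0.5

out = flash_attn_func(q, k, v, softmax_scale=scale)
ref = attention_reference(q.detach(), k.detach(), v.detach(), scale)
assert torch.allclose(out.float(), ref, atol=2e-2, rtol=2e-2)
out.sum().backward()  # exercises the backward kernel; gradients land in q.grad, k.grad, v.grad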

7 changes: 3 additions & 4 deletions test/test_flash_attention_forward.py
@@ -1,13 +1,12 @@
 import os
 import unittest
+import pytest
 
-from flash_attn import flash_attn_func
 import numpy as np
 import torch
 import torch.nn.functional as F
-import pytest
 import torch_xla
 import torch_xla.core.xla_model as xm
 
+from flash_attn import flash_attn_func
 import torchacc as ta


11 changes: 5 additions & 6 deletions test/test_flash_attention_varlen_backward.py
@@ -1,14 +1,13 @@
-import os, sys
 import unittest
+import os
+import pytest
 
-from flash_attn import flash_attn_varlen_func
-from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
 import numpy as np
 import torch
 import torch.nn.functional as F
-import pytest
 import torch_xla
 import torch_xla.core.xla_model as xm
 
+from flash_attn import flash_attn_varlen_func
+from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
 import flash_attn_2_cuda as flash_attn_cuda
 import torchacc as ta
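
The varlen entry point consumes packed (unpadded) sequences rather than a padded batch, described by cumulative offsets. A minimal sketch of driving flash_attn_varlen_func directly (illustrative sizes, not this repository's test code):

import torch
from flash_attn import flash_attn_varlen_func

# Two sequences of lengths 100 and 28 packed into one (total_tokens, ...) tensor.
seqlens = [100, 28]
total, nheads, head_dim = sum(seqlens), 8, 64
q = torch.randn(total, nheads, head_dim, dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

# Cumulative lengths with a leading zero: sequence i occupies rows cu[i]:cu[i+1].
cu_seqlens = torch.tensor([0, 100, 128], dtype=torch.int32, device="cuda")

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seqlens), max_seqlen_k=max(seqlens),
)  # -> (total_tokens, nheads, head_dim)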

13 changes: 6 additions & 7 deletions test/test_flash_attention_varlen_forward.py
@@ -1,15 +1,14 @@
-import os, sys
 import unittest
+import os
+import pytest
 
-from flash_attn import flash_attn_varlen_func
-from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
 import numpy as np
 import torch
 import torch.nn.functional as F
-import pytest
-import torchacc as ta
 import torch_xla
 
+from flash_attn import flash_attn_varlen_func
+from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
+import torchacc as ta
 
 
 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
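
The hunk is truncated after the first line of this helper. For reference, the conventional form of _get_unpad_data in FlashAttention integrations (HuggingFace-style; this repository's body may differ) continues along these lines:

import torch
import torch.nn.functional as F

def _get_unpad_data(attention_mask):
    # Valid (non-padding) token count per sequence.
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    # Flat indices of valid tokens, used by unpad_input/pad_input to gather and scatter.
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    # Cumulative lengths with a leading zero -- the cu_seqlens the varlen kernels expect.
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch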
32 changes: 13 additions & 19 deletions torch_xla/csrc/ops/flash_attention_backward.h
@@ -8,24 +8,18 @@ namespace torch_xla {
 
 class FlashAttentionBackward : public XlaNode {
  public:
-  FlashAttentionBackward(const torch::lazy::Value& dout,
-                         const torch::lazy::Value& q,
-                         const torch::lazy::Value& k,
-                         const torch::lazy::Value& v,
-                         const torch::lazy::Value& out,
-                         const torch::lazy::Value& softmax_lse,
-                         const torch::lazy::Value& rng_state,
-                         const std::string params);
-
-  FlashAttentionBackward(const torch::lazy::Value& dout,
-                         const torch::lazy::Value& q,
-                         const torch::lazy::Value& k,
-                         const torch::lazy::Value& v,
-                         const torch::lazy::Value& out,
-                         const torch::lazy::Value& softmax_lse,
-                         const torch::lazy::Value& rng_state,
-                         const torch::lazy::Value& alibi_slopes,
-                         const std::string params);
+  FlashAttentionBackward(
+      const torch::lazy::Value& dout, const torch::lazy::Value& q,
+      const torch::lazy::Value& k, const torch::lazy::Value& v,
+      const torch::lazy::Value& out, const torch::lazy::Value& softmax_lse,
+      const torch::lazy::Value& rng_state, const std::string params);
+
+  FlashAttentionBackward(
+      const torch::lazy::Value& dout, const torch::lazy::Value& q,
+      const torch::lazy::Value& k, const torch::lazy::Value& v,
+      const torch::lazy::Value& out, const torch::lazy::Value& softmax_lse,
+      const torch::lazy::Value& rng_state,
+      const torch::lazy::Value& alibi_slopes, const std::string params);
 
   torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;
 
@@ -37,4 +31,4 @@ class FlashAttentionBackward : public XlaNode {
 
 }  // namespace torch_xla
 
-#endif  // XLA_TORCH_XLA_CSRC_OPS_FLASH_ATTENTION_VARLEN_BACKWARD_H_
+#endif  // XLA_TORCH_XLA_CSRC_OPS_FLASH_ATTENTION_BACKWARD_H_
3 changes: 1 addition & 2 deletions torch_xla/csrc/ops/flash_attention_forward.h
@@ -10,8 +10,7 @@ class FlashAttentionForward : public XlaNode {
  public:
   FlashAttentionForward(const torch::lazy::Value& q,
                         const torch::lazy::Value& k,
-                        const torch::lazy::Value& v,
-                        const std::string params);
+                        const torch::lazy::Value& v, const std::string params);
 
   FlashAttentionForward(const torch::lazy::Value& q,
                         const torch::lazy::Value& k,
38 changes: 16 additions & 22 deletions torch_xla/csrc/ops/flash_attention_varlen_backward.h
@@ -8,28 +8,22 @@ namespace torch_xla {
 
 class FlashAttentionVarlenBackward : public XlaNode {
  public:
-  FlashAttentionVarlenBackward(const torch::lazy::Value& dout,
-                               const torch::lazy::Value& q,
-                               const torch::lazy::Value& k,
-                               const torch::lazy::Value& v,
-                               const torch::lazy::Value& out,
-                               const torch::lazy::Value& softmax_lse,
-                               const torch::lazy::Value& cu_seqlens_q,
-                               const torch::lazy::Value& cu_seqlens_k,
-                               const torch::lazy::Value& rng_state,
-                               const std::string params);
-
-  FlashAttentionVarlenBackward(const torch::lazy::Value& dout,
-                               const torch::lazy::Value& q,
-                               const torch::lazy::Value& k,
-                               const torch::lazy::Value& v,
-                               const torch::lazy::Value& out,
-                               const torch::lazy::Value& softmax_lse,
-                               const torch::lazy::Value& cu_seqlens_q,
-                               const torch::lazy::Value& cu_seqlens_k,
-                               const torch::lazy::Value& rng_state,
-                               const torch::lazy::Value& alibi_slopes,
-                               const std::string params);
+  FlashAttentionVarlenBackward(
+      const torch::lazy::Value& dout, const torch::lazy::Value& q,
+      const torch::lazy::Value& k, const torch::lazy::Value& v,
+      const torch::lazy::Value& out, const torch::lazy::Value& softmax_lse,
+      const torch::lazy::Value& cu_seqlens_q,
+      const torch::lazy::Value& cu_seqlens_k,
+      const torch::lazy::Value& rng_state, const std::string params);
+
+  FlashAttentionVarlenBackward(
+      const torch::lazy::Value& dout, const torch::lazy::Value& q,
+      const torch::lazy::Value& k, const torch::lazy::Value& v,
+      const torch::lazy::Value& out, const torch::lazy::Value& softmax_lse,
+      const torch::lazy::Value& cu_seqlens_q,
+      const torch::lazy::Value& cu_seqlens_k,
+      const torch::lazy::Value& rng_state,
+      const torch::lazy::Value& alibi_slopes, const std::string params);
 
   torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;

12 changes: 6 additions & 6 deletions torch_xla/csrc/ops/flash_attention_varlen_forward.cpp
@@ -21,7 +21,8 @@ xla::Shape NodeOutputShape(const torch::lazy::Value& q) {
   auto q_shape = xla::SpanToVector(GetXlaShape(q).dimensions());
   xla::Shape softmax_lse_shape = xla::ShapeUtil::MakeShape(
       xla::PrimitiveType::F32,
-      {q_shape[0], q_shape[2], q_shape[1]});  // batch_size, num_heads, seqlen_q(padding)
+      {q_shape[0], q_shape[2],
+       q_shape[1]});  // batch_size, num_heads, seqlen_q(padding)
   xla::Shape rng_state_shape =
       xla::ShapeUtil::MakeShape(xla::PrimitiveType::U64, {2});
   xla::Shape cu_seqlens_shape =
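
The index shuffle above derives the log-sum-exp shape from q's (batch_size, seqlen_q, num_heads, head_dim) layout, per the trailing comment. In Python terms, with assumed sizes:

q_shape = (2, 128, 8, 64)  # (batch_size, seqlen_q, num_heads, head_dim), illustrative
softmax_lse_shape = (q_shape[0], q_shape[2], q_shape[1])  # (batch, num_heads, seqlen_q), fp32
rng_state_shape = (2,)  # two u64 words of dropout RNG state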
@@ -311,11 +312,9 @@ FlashAttentionVarlenForward::FlashAttentionVarlenForward(
 FlashAttentionVarlenForward::FlashAttentionVarlenForward(
     const torch::lazy::Value& q, const torch::lazy::Value& k,
     const torch::lazy::Value& v, const torch::lazy::Value& attention_mask,
-    const torch::lazy::Value& alibi_slopes,
-    const std::string params)
+    const torch::lazy::Value& alibi_slopes, const std::string params)
     : XlaNode(xla_flash_attention_forward,
-              {q, k, v, attention_mask, alibi_slopes},
-              NodeOutputShape(q),
+              {q, k, v, attention_mask, alibi_slopes}, NodeOutputShape(q),
               /*num_outputs=*/5, torch::lazy::MHash(params)),
       params_(params) {}

@@ -327,7 +326,8 @@ torch::lazy::NodePtr FlashAttentionVarlenForward::Clone(
                                                   operands.at(4), params_);
   } else {
     torch::lazy::MakeNode<FlashAttentionVarlenForward>(
-        operands.at(0), operands.at(1), operands.at(2), operands.at(3), params_);
+        operands.at(0), operands.at(1), operands.at(2), operands.at(3),
+        params_);
   }
 }

6 changes: 2 additions & 4 deletions torch_xla/csrc/tensor_methods.cpp
@@ -667,8 +667,7 @@ std::vector<XLATensorPtr> flash_attention_backward(
     const XLATensorPtr& dout, const XLATensorPtr& q, const XLATensorPtr& k,
     const XLATensorPtr& v, const XLATensorPtr& out,
     const XLATensorPtr& softmax_lse, const XLATensorPtr& rng_state,
-    const XLATensorPtr& alibi_slopes,
-    const std::string& params) {
+    const XLATensorPtr& alibi_slopes, const std::string& params) {
   if (alibi_slopes) {
     torch::lazy::NodePtr node = torch::lazy::MakeNode<FlashAttentionBackward>(
         dout->GetIrValue(), q->GetIrValue(), k->GetIrValue(), v->GetIrValue(),
@@ -689,8 +688,7 @@ std::vector<XLATensorPtr> flash_attention_varlen_backward(
     const XLATensorPtr& v, const XLATensorPtr& out,
     const XLATensorPtr& softmax_lse, const XLATensorPtr& cu_seqlens_q,
     const XLATensorPtr& cu_seqlens_k, const XLATensorPtr& rng_state,
-    const XLATensorPtr& alibi_slopes,
-    const std::string& params) {
+    const XLATensorPtr& alibi_slopes, const std::string& params) {
   if (alibi_slopes) {
     torch::lazy::NodePtr node =
         torch::lazy::MakeNode<FlashAttentionVarlenBackward>(
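
Both overload pairs above exist because alibi_slopes is optional: when it is set, the tensor method builds the IR node with the extra operand, otherwise it uses the shorter constructor. At the flash-attn API level the same optionality looks like this (a sketch assuming the standard flash_attn_func signature, not code from this commit):

import torch
from flash_attn import flash_attn_func

q, k, v = (torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda") for _ in range(3))

out_plain = flash_attn_func(q, k, v)  # no ALiBi: the overload without alibi_slopes
alibi_slopes = torch.rand(q.shape[2], dtype=torch.float32, device="cuda")  # one slope per head
out_alibi = flash_attn_func(q, k, v, alibi_slopes=alibi_slopes)  # the extra-operand overload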
6 changes: 2 additions & 4 deletions torch_xla/csrc/tensor_methods.h
@@ -127,16 +127,14 @@ std::vector<XLATensorPtr> flash_attention_backward(
     const XLATensorPtr& dout, const XLATensorPtr& q, const XLATensorPtr& k,
     const XLATensorPtr& v, const XLATensorPtr& out,
     const XLATensorPtr& softmax_lse, const XLATensorPtr& rng_state,
-    const XLATensorPtr& alibi_slopes,
-    const std::string& params);
+    const XLATensorPtr& alibi_slopes, const std::string& params);
 
 std::vector<XLATensorPtr> flash_attention_varlen_backward(
     const XLATensorPtr& dout, const XLATensorPtr& q, const XLATensorPtr& k,
     const XLATensorPtr& v, const XLATensorPtr& out,
     const XLATensorPtr& softmax_lse, const XLATensorPtr& cu_seqlens_q,
     const XLATensorPtr& cu_seqlens_k, const XLATensorPtr& rng_state,
-    const XLATensorPtr& alibi_slopes,
-    const std::string& params);
+    const XLATensorPtr& alibi_slopes, const std::string& params);
 
 std::vector<XLATensorPtr> user_computation(
     const std::string& opname, absl::Span<const XLATensorPtr> inputs,
