Skip to content

Commit

Permalink
Merge pull request #1 from mira-6/mira-6-patch-1
Browse files Browse the repository at this point in the history
Update attention_couple.py
  • Loading branch information
mira-6 authored Jul 25, 2024
2 parents d7183d3 + 85db11b commit 4f164bd
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion attention_couple.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,11 @@ def patch(q, k, v, extra_options):
q_target = q_list[i].repeat(length, 1, 1)
k = torch.cat([k[i].unsqueeze(0).repeat(b,1,1) for i in range(length)], dim=0)
v = torch.cat([v[i].unsqueeze(0).repeat(b,1,1) for i in range(length)], dim=0)


if k.dtype != q_target.dtype or v.dtype != q_target.dtype:
# Ensure all dtypes match
k = k.to(q_target.dtype)
v = v.to(q_target.dtype)
qkv = optimized_attention(q_target, k, v, extra_options["n_heads"])
qkv = qkv * masks
qkv = qkv.view(length, b, -1, module.heads * module.dim_head).sum(dim=0)
Expand Down

0 comments on commit 4f164bd

Please sign in to comment.