GH-7 Revert attempt to fix prompt size difference failure
Danand committed Jul 16, 2024
1 parent e2f9b11 commit cc40230
Showing 1 changed file with 4 additions and 2 deletions.
attention_couple.py: 6 changes (4 additions & 2 deletions)
@@ -116,8 +116,8 @@ def patch(q, k, v, extra_options):
 masks_uncond = get_masks_from_q(self.negative_positive_masks[0], q_list[0], extra_options["original_shape"])
 masks_cond = get_masks_from_q(self.negative_positive_masks[1], q_list[0], extra_options["original_shape"])
 
-context_uncond = self.negative_positive_conds[0][0]
-context_cond = self.negative_positive_conds[1][0]
+context_uncond = torch.cat([cond for cond in self.negative_positive_conds[0]], dim=0)
+context_cond = torch.cat([cond for cond in self.negative_positive_conds[1]], dim=0)
 
 k_uncond = module.to_k(context_uncond)
 k_cond = module.to_k(context_cond)
@@ -138,6 +138,8 @@ def patch(q, k, v, extra_options):
 length = len_neg
 
 q_target = q_list[i].repeat(length, 1, 1)
+k = torch.cat([k[i].unsqueeze(0).repeat(b,1,1) for i in range(length)], dim=0)
+v = torch.cat([v[i].unsqueeze(0).repeat(b,1,1) for i in range(length)], dim=0)
 
 qkv = optimized_attention(q_target, k, v, extra_options["n_heads"])
 qkv = qkv * masks
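For context on what this revert restores: the two removed lines had fed only the first regional conditioning tensor into the attention patch, while the re-added lines concatenate every regional conditioning along the batch dimension and repeat each prompt's k/v per latent before one combined attention call. Below is a minimal, self-contained sketch of that restored pattern. It is not the repository's code: the function name attention_over_regions, the to_k/to_v parameters, the mask layout, and the use of single-head scaled_dot_product_attention in place of optimized_attention are all illustrative assumptions.

import torch
import torch.nn.functional as F

def attention_over_regions(q, conds, masks, to_k, to_v):
    """Attend to every regional prompt at once and blend the results with masks.

    q:     (b, tokens_q, dim)           latent queries
    conds: list of (1, tokens_c, dim)   one conditioning tensor per region; all must
                                        share tokens_c for torch.cat to succeed
    masks: (n_regions, b, tokens_q, 1)  per-region weights that sum to 1 over regions
    """
    b = q.shape[0]
    length = len(conds)

    # Concatenate every regional prompt along the batch dimension,
    # mirroring the torch.cat(...) lines restored by this commit.
    context = torch.cat(list(conds), dim=0)   # (length, tokens_c, dim)
    k = to_k(context)                         # (length, tokens_c, dim)
    v = to_v(context)

    # Tile q once per prompt and repeat each prompt's k/v for every latent in the
    # batch so the leading batch dimensions line up (length * b everywhere).
    q_target = q.repeat(length, 1, 1)
    k = torch.cat([k[i].unsqueeze(0).repeat(b, 1, 1) for i in range(length)], dim=0)
    v = torch.cat([v[i].unsqueeze(0).repeat(b, 1, 1) for i in range(length)], dim=0)

    # One attention call covers all regions; plain SDPA stands in for optimized_attention.
    out = F.scaled_dot_product_attention(q_target, k, v)   # (length*b, tokens_q, dim)

    # Weight each region's output by its mask and sum the regions back together.
    out = out.view(length, b, -1, q.shape[-1]) * masks
    return out.sum(dim=0)                                   # (b, tokens_q, dim)

# Example usage with toy shapes (hypothetical values, q and the projections share dim):
# to_k = to_v = torch.nn.Linear(768, 768, bias=False)
# q = torch.randn(2, 4096, 768)
# conds = [torch.randn(1, 77, 768) for _ in range(3)]
# masks = torch.softmax(torch.randn(3, 2, 4096, 1), dim=0)
# out = attention_over_regions(q, conds, masks, to_k, to_v)   # (2, 4096, 768)

Note that torch.cat along dim=0 only works when every regional conditioning shares the same token length, which is presumably the "prompt size difference" failure named in the commit title; the reverted attempt had sidestepped it by using only the first conditioning tensor.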
