diff --git a/attention_couple.py b/attention_couple.py
index af3067d..5289605 100644
--- a/attention_couple.py
+++ b/attention_couple.py
@@ -140,7 +140,11 @@ def patch(q, k, v, extra_options):
             q_target = q_list[i].repeat(length, 1, 1)
             k = torch.cat([k[i].unsqueeze(0).repeat(b,1,1) for i in range(length)], dim=0)
             v = torch.cat([v[i].unsqueeze(0).repeat(b,1,1) for i in range(length)], dim=0)
-
+
+            if k.dtype != q_target.dtype or v.dtype != q_target.dtype:
+                # Ensure all dtypes match
+                k = k.to(q_target.dtype)
+                v = v.to(q_target.dtype)
             qkv = optimized_attention(q_target, k, v, extra_options["n_heads"])
             qkv = qkv * masks
             qkv = qkv.view(length, b, -1, module.heads * module.dim_head).sum(dim=0)
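For context, here is a minimal standalone sketch of the failure mode this patch guards against: when `q` and the concatenated `k`/`v` are produced under different precisions (e.g. mixed-precision inference), fused attention kernels reject the mismatched dtypes, and casting `k`/`v` to `q`'s dtype resolves it. The helper name `toy_attention` and the tensor shapes are illustrative assumptions; only plain `torch` with `scaled_dot_product_attention` (PyTorch >= 2.0) is used here, not the repo's `optimized_attention`.

```python
import torch
import torch.nn.functional as F

def toy_attention(q, k, v):
    # Mirrors the patched code: cast k/v to q's dtype before attention.
    if k.dtype != q.dtype or v.dtype != q.dtype:
        k = k.to(q.dtype)
        v = v.to(q.dtype)
    return F.scaled_dot_product_attention(q, k, v)

# (batch, heads, seq_len, head_dim); k/v deliberately built in a wider dtype.
q = torch.randn(1, 8, 16, 64, dtype=torch.float16)
k = torch.randn(1, 8, 16, 64, dtype=torch.float32)
v = torch.randn(1, 8, 16, 64, dtype=torch.float32)

out = toy_attention(q, k, v)  # without the cast, SDPA raises a dtype-mismatch RuntimeError
print(out.dtype)              # torch.float16
```

Casting `k`/`v` toward `q` (rather than the other way around) keeps the attention output in the dtype the surrounding layer expects, which is why the diff uses `q_target.dtype` as the reference.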