From 85db11b58844a94248e565063fc50c5b1a40a879 Mon Sep 17 00:00:00 2001
From: Fox Rayside <134891345+mira-6@users.noreply.github.com>
Date: Thu, 25 Jul 2024 12:45:38 -0400
Subject: [PATCH] Update attention_couple.py

Ensure all dtypes match.
---
 attention_couple.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/attention_couple.py b/attention_couple.py
index af3067d..5289605 100644
--- a/attention_couple.py
+++ b/attention_couple.py
@@ -140,7 +140,11 @@ def patch(q, k, v, extra_options):
         q_target = q_list[i].repeat(length, 1, 1)
         k = torch.cat([k[i].unsqueeze(0).repeat(b,1,1) for i in range(length)], dim=0)
         v = torch.cat([v[i].unsqueeze(0).repeat(b,1,1) for i in range(length)], dim=0)
-
+
+        if k.dtype != q_target.dtype or v.dtype != q_target.dtype:
+            # Ensure all dtypes match
+            k = k.to(q_target.dtype)
+            v = v.to(q_target.dtype)
         qkv = optimized_attention(q_target, k, v, extra_options["n_heads"])
         qkv = qkv * masks
         qkv = qkv.view(length, b, -1, module.heads * module.dim_head).sum(dim=0)
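
For context, here is a minimal standalone sketch of the guard this patch adds: casting `k` and `v` to `q`'s dtype before the attention call, so mixed-precision runs (e.g. a fp16 query meeting fp32 key/value tensors) don't fail inside the attention kernel. The helper name `match_kv_dtype` and the tensor shapes are illustrative assumptions, not part of attention_couple.py:

```python
import torch

def match_kv_dtype(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    """Cast k and v to q's dtype so the attention matmuls see uniform dtypes.

    Mirrors the patch above: concatenated/repeated prompt embeddings can end
    up in a different dtype than the query under mixed precision, and the
    downstream matmul would then raise a dtype-mismatch error.
    """
    if k.dtype != q.dtype or v.dtype != q.dtype:
        k = k.to(q.dtype)
        v = v.to(q.dtype)
    return k, v

# Hypothetical shapes (batch, tokens, dim): q in fp16, k/v in fp32.
q = torch.randn(2, 64, 320, dtype=torch.float16)
k = torch.randn(2, 77, 320, dtype=torch.float32)
v = torch.randn(2, 77, 320, dtype=torch.float32)

k, v = match_kv_dtype(q, k, v)
assert k.dtype == v.dtype == q.dtype  # all fp16 now
```

The check-before-cast keeps the common case (dtypes already equal) free of extra copies; `Tensor.to` is a no-op only in spirit here, so skipping it when unnecessary avoids allocating new tensors on every attention call.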