From d58688013e007c3a782c405ef837087c418d52b7 Mon Sep 17 00:00:00 2001 From: scutpaul Date: Mon, 6 Nov 2023 15:24:28 +0800 Subject: [PATCH] fix bug for issue #38 --- lvdm/modules/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lvdm/modules/attention.py b/lvdm/modules/attention.py index 3d3140c..bceba7d 100644 --- a/lvdm/modules/attention.py +++ b/lvdm/modules/attention.py @@ -120,7 +120,7 @@ def forward(self, x, context=None, mask=None): del k_ip sim_ip = sim_ip.softmax(dim=-1) out_ip = torch.einsum('b i j, b j d -> b i d', sim_ip, v_ip) - out_ip = rearrange(out, '(b h) n d -> b n (h d)', h=h) + out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h) out = out + self.image_cross_attention_scale * out_ip del q