diff --git a/lvdm/modules/attention.py b/lvdm/modules/attention.py
index 3d3140c..bceba7d 100644
--- a/lvdm/modules/attention.py
+++ b/lvdm/modules/attention.py
@@ -120,7 +120,7 @@ def forward(self, x, context=None, mask=None):
         del k_ip
         sim_ip = sim_ip.softmax(dim=-1)
         out_ip = torch.einsum('b i j, b j d -> b i d', sim_ip, v_ip)
-        out_ip = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+        out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h)
         out = out + self.image_cross_attention_scale * out_ip
         del q
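
Note: the fix corrects a variable mix-up where the head-merging `rearrange` was applied to `out` (the main attention output) instead of `out_ip`, so the image cross-attention result was silently discarded and the main output was added to itself. Below is a minimal standalone sketch of the corrected branch; only the `einsum`/`rearrange` lines mirror the diff, while the tensor shapes, head count, and scaling are illustrative assumptions:

```python
import torch
from einops import rearrange

# Illustrative dimensions (assumptions, not from the diff):
# batch, heads, query tokens, image-context tokens, head dim
b, h, n, m, d = 2, 8, 16, 4, 64

q = torch.randn(b * h, n, d)     # queries, heads already folded into batch
k_ip = torch.randn(b * h, m, d)  # image-context keys
v_ip = torch.randn(b * h, m, d)  # image-context values

# Scaled dot-product attention over the image context
sim_ip = torch.einsum('b i d, b j d -> b i j', q, k_ip) * d ** -0.5
sim_ip = sim_ip.softmax(dim=-1)
out_ip = torch.einsum('b i j, b j d -> b i d', sim_ip, v_ip)

# The fixed line: rearrange out_ip (not out), unfolding heads
# back out of the batch dim and merging them into channels.
out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h)
assert out_ip.shape == (b, n, h * d)
```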