diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py
index 37f0ebfa1d..2678816285 100644
--- a/ldm/models/diffusion/cross_attention_control.py
+++ b/ldm/models/diffusion/cross_attention_control.py
@@ -584,7 +584,9 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
         # print(f"SwapCrossAttnContext for {attention_type} active")
 
         batch_size, sequence_length, _ = hidden_states.shape
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
+        attention_mask = attn.prepare_attention_mask(
+            attention_mask=attention_mask, target_length=sequence_length,
+            batch_size=batch_size)
 
         query = attn.to_q(hidden_states)
         dim = query.shape[-1]
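
Not part of the patch: a minimal, self-contained sketch that approximates what diffusers' `Attention.prepare_attention_mask` does with these keyword arguments, to show why `batch_size` is now passed explicitly. The helper name `prepare_attention_mask_sketch`, the head count, and the tensor shapes below are illustrative assumptions, not code from the repository or from diffusers itself.

```python
import torch
import torch.nn.functional as F


def prepare_attention_mask_sketch(attention_mask, target_length, batch_size, heads=8):
    """Rough approximation (assumed behavior): pad the mask to the key length,
    then repeat it once per attention head."""
    if attention_mask is None:
        return None
    current_length = attention_mask.shape[-1]
    if current_length != target_length:
        # Pad the last dimension so the mask covers every key position.
        attention_mask = F.pad(attention_mask, (0, target_length - current_length), value=0.0)
    # Expand from (batch, ...) to (batch * heads, ...); this is why the caller
    # has to supply batch_size instead of letting it be inferred.
    if attention_mask.shape[0] < batch_size * heads:
        attention_mask = attention_mask.repeat_interleave(heads, dim=0)
    return attention_mask


hidden_states = torch.randn(2, 77, 320)             # assumed (batch, seq, dim)
batch_size, sequence_length, _ = hidden_states.shape
mask = torch.zeros(batch_size, sequence_length)      # additive mask: 0 = attend

prepared = prepare_attention_mask_sketch(
    attention_mask=mask, target_length=sequence_length, batch_size=batch_size)
print(prepared.shape)                                 # torch.Size([16, 77])
```

Passing keyword arguments in the patched call also keeps it valid if the positional order of `prepare_attention_mask` changes between diffusers releases.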