From b21bd6f428673448309bf3f096cf884fa713c37f Mon Sep 17 00:00:00 2001
From: Jonathan <34005131+JPPhoto@users.noreply.github.com>
Date: Mon, 20 Feb 2023 11:12:47 -0600
Subject: [PATCH] Fix crash on calling diffusers' prepare_attention_mask

Diffusers' `prepare_attention_mask` was crashing when we didn't pass in a
batch size.
---
 ldm/models/diffusion/cross_attention_control.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py
index 37f0ebfa1d..2678816285 100644
--- a/ldm/models/diffusion/cross_attention_control.py
+++ b/ldm/models/diffusion/cross_attention_control.py
@@ -584,7 +584,9 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
         # print(f"SwapCrossAttnContext for {attention_type} active")
 
         batch_size, sequence_length, _ = hidden_states.shape
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
+        attention_mask = attn.prepare_attention_mask(
+            attention_mask=attention_mask, target_length=sequence_length,
+            batch_size=batch_size)
 
         query = attn.to_q(hidden_states)
         dim = query.shape[-1]
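
Note (not part of the patch): a minimal sketch of the keyword-argument call this
change switches to, shown in isolation. The `CrossAttention` import path and
constructor arguments below are assumptions based on the diffusers releases
current around the time of this patch (~0.13) and may differ in other versions.

    # Sketch: prepare an additive attention mask with an explicit batch size.
    import torch
    from diffusers.models.cross_attention import CrossAttention  # path moved in later diffusers versions

    attn = CrossAttention(query_dim=320, heads=8, dim_head=40)

    batch_size, sequence_length = 2, 77
    # Additive mask: 0.0 keeps a token; large negative values would mask it out.
    attention_mask = torch.zeros(batch_size, sequence_length)

    prepared = attn.prepare_attention_mask(
        attention_mask=attention_mask,
        target_length=sequence_length,
        batch_size=batch_size,  # omitting this is what triggered the crash
    )
    print(prepared.shape)  # mask repeated per attention head; exact shape is version-dependent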