Fix crash on calling diffusers' prepare_attention_mask (#2743)

Diffusers' `prepare_attention_mask` was crashing when we didn't pass in
a batch size.
Commit aab8263c31 by Lincoln Stein, 2023-02-20 12:35:33 -05:00, committed via GitHub
(GPG Key ID: 4AEE18F83AFDEB23 — no known key found for this signature in database).

View File

@ -584,7 +584,9 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
# print(f"SwapCrossAttnContext for {attention_type} active")
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
attention_mask = attn.prepare_attention_mask(
attention_mask=attention_mask, target_length=sequence_length,
batch_size=batch_size)
query = attn.to_q(hidden_states)
dim = query.shape[-1]