Fix a minor bug in the logic of the IPAttnProcessor2_0. The change won't have any functional effect, since this attention implementation was only being used for cross-attention, but the logic should be correct now in case we wanted to use it for self-attention.

This commit is contained in:
Ryan Dick 2024-02-16 09:10:47 -05:00
parent ba4788007f
commit 38248b988f

View File

@ -73,6 +73,9 @@ class IPAttnProcessor2_0(torch.nn.Module):
ip_adapter_image_prompt_embeds (torch.Tensor): The image prompt embeddings.
Shape: (batch_size, num_ip_images, seq_len, ip_embedding_len).
"""
# If true, we are doing cross-attention, if false we are doing self-attention.
is_cross_attention = encoder_hidden_states is not None
residual = hidden_states
if attn.spatial_norm is not None:
@ -124,7 +127,7 @@ class IPAttnProcessor2_0(torch.nn.Module):
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
if is_cross_attention:
# If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
# we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
assert ip_adapter_image_prompt_embeds is not None