Pass IP-Adapter conditioning via cross_attention_kwargs instead of concatenating to the text embedding. This avoids interference with other features that manipulate the text embedding (e.g. long prompts).

This commit is contained in:
Ryan Dick
2023-09-08 11:47:36 -04:00
parent ddc148b70b
commit b2d5b53b5f
5 changed files with 135 additions and 68 deletions

View File

@ -19,12 +19,42 @@ class AttnProcessor(DiffusersAttnProcessor, nn.Module):
DiffusersAttnProcessor.__init__(self)
nn.Module.__init__(self)
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
ip_adapter_image_prompt_embeds=None,
):
"""Re-definition of DiffusersAttnProcessor.__call__(...) that accepts and ignores the
ip_adapter_image_prompt_embeds parameter.
"""
return DiffusersAttnProcessor.__call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, temb)
class AttnProcessor2_0(DiffusersAttnProcessor2_0, nn.Module):
def __init__(self):
DiffusersAttnProcessor2_0.__init__(self)
nn.Module.__init__(self)
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
ip_adapter_image_prompt_embeds=None,
):
"""Re-definition of DiffusersAttnProcessor2_0.__call__(...) that accepts and ignores the
ip_adapter_image_prompt_embeds parameter.
"""
return DiffusersAttnProcessor2_0.__call__(
self, attn, hidden_states, encoder_hidden_states, attention_mask, temb
)
class IPAttnProcessor(nn.Module):
r"""
@ -32,21 +62,17 @@ class IPAttnProcessor(nn.Module):
Args:
hidden_size (`int`):
The hidden size of the attention layer.
image_embedding_len (`int`):
The length of the IP-Adapter image embedding. It is assumed that the last `image_embedding_len` 'tokens' of
the `encoder_hidden_states` are the IP-Adapter image embeddings.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
scale (`float`, defaults to 1.0):
the weight scale of image prompt.
"""
def __init__(self, hidden_size, image_embedding_len, cross_attention_dim=None, scale=1.0):
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0):
super().__init__()
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.image_embedding_len = image_embedding_len
self.scale = scale
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
@ -59,7 +85,18 @@ class IPAttnProcessor(nn.Module):
encoder_hidden_states=None,
attention_mask=None,
temb=None,
ip_adapter_image_prompt_embeds=None,
):
if encoder_hidden_states is not None:
# If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
# we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
assert ip_adapter_image_prompt_embeds is not None
# The batch dimensions should match.
assert ip_adapter_image_prompt_embeds.shape[0] == encoder_hidden_states.shape[0]
# The channel dimensions should match.
assert ip_adapter_image_prompt_embeds.shape[2] == encoder_hidden_states.shape[2]
ip_hidden_states = ip_adapter_image_prompt_embeds
residual = hidden_states
if attn.spatial_norm is not None:
@ -86,12 +123,6 @@ class IPAttnProcessor(nn.Module):
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
# Split text encoder hidden states and image encoder hidden state.
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, : -self.image_embedding_len, :],
encoder_hidden_states[:, -self.image_embedding_len :, :],
)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
@ -103,18 +134,18 @@ class IPAttnProcessor(nn.Module):
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# for ip-adapter
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
if ip_hidden_states is not None:
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)
ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
hidden_states = hidden_states + self.scale * ip_hidden_states
hidden_states = hidden_states + self.scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
@ -138,16 +169,13 @@ class IPAttnProcessor2_0(torch.nn.Module):
Args:
hidden_size (`int`):
The hidden size of the attention layer.
image_embedding_len (`int`):
The length of the IP-Adapter image embedding. It is assumed that the last `image_embedding_len` 'tokens' of
the `encoder_hidden_states` are the IP-Adapter image embeddings.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
scale (`float`, defaults to 1.0):
the weight scale of image prompt.
"""
def __init__(self, hidden_size, image_embedding_len, cross_attention_dim=None, scale=1.0):
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
@ -155,7 +183,6 @@ class IPAttnProcessor2_0(torch.nn.Module):
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.text_context_len = text_context_len
self.scale = scale
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
@ -168,7 +195,18 @@ class IPAttnProcessor2_0(torch.nn.Module):
encoder_hidden_states=None,
attention_mask=None,
temb=None,
ip_adapter_image_prompt_embeds=None,
):
if encoder_hidden_states is not None:
# If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
# we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
assert ip_adapter_image_prompt_embeds is not None
# The batch dimensions should match.
assert ip_adapter_image_prompt_embeds.shape[0] == encoder_hidden_states.shape[0]
# The channel dimensions should match.
assert ip_adapter_image_prompt_embeds.shape[2] == encoder_hidden_states.shape[2]
ip_hidden_states = ip_adapter_image_prompt_embeds
residual = hidden_states
if attn.spatial_norm is not None:
@ -200,12 +238,6 @@ class IPAttnProcessor2_0(torch.nn.Module):
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
# Split text encoder hidden states and image encoder hidden state.
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, : -self.image_embedding_len, :],
encoder_hidden_states[:, -self.image_embedding_len :, :],
)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
@ -226,23 +258,23 @@ class IPAttnProcessor2_0(torch.nn.Module):
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# for ip-adapter
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
if ip_hidden_states:
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
hidden_states = hidden_states + self.scale * ip_hidden_states
hidden_states = hidden_states + self.scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)

View File

@ -92,7 +92,6 @@ class IPAdapter:
print("swapping in IPAttnProcessor for", name)
attn_procs[name] = IPAttnProcessor(
hidden_size=hidden_size,
image_embedding_len=self.num_tokens,
cross_attention_dim=cross_attention_dim,
scale=1.0,
).to(self.device, dtype=torch.float16)