mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Pass IP-Adapter conditioning via cross_attention_kwargs instead of concatenating to the text embedding. This avoids interference with other features that manipulate the text embedding (e.g. long prompts).
This commit is contained in:
@ -19,12 +19,42 @@ class AttnProcessor(DiffusersAttnProcessor, nn.Module):
|
||||
DiffusersAttnProcessor.__init__(self)
|
||||
nn.Module.__init__(self)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
temb=None,
|
||||
ip_adapter_image_prompt_embeds=None,
|
||||
):
|
||||
"""Re-definition of DiffusersAttnProcessor.__call__(...) that accepts and ignores the
|
||||
ip_adapter_image_prompt_embeds parameter.
|
||||
"""
|
||||
return DiffusersAttnProcessor.__call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, temb)
|
||||
|
||||
|
||||
class AttnProcessor2_0(DiffusersAttnProcessor2_0, nn.Module):
|
||||
def __init__(self):
|
||||
DiffusersAttnProcessor2_0.__init__(self)
|
||||
nn.Module.__init__(self)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
temb=None,
|
||||
ip_adapter_image_prompt_embeds=None,
|
||||
):
|
||||
"""Re-definition of DiffusersAttnProcessor2_0.__call__(...) that accepts and ignores the
|
||||
ip_adapter_image_prompt_embeds parameter.
|
||||
"""
|
||||
return DiffusersAttnProcessor2_0.__call__(
|
||||
self, attn, hidden_states, encoder_hidden_states, attention_mask, temb
|
||||
)
|
||||
|
||||
|
||||
class IPAttnProcessor(nn.Module):
|
||||
r"""
|
||||
@ -32,21 +62,17 @@ class IPAttnProcessor(nn.Module):
|
||||
Args:
|
||||
hidden_size (`int`):
|
||||
The hidden size of the attention layer.
|
||||
image_embedding_len (`int`):
|
||||
The length of the IP-Adapter image embedding. It is assumed that the last `image_embedding_len` 'tokens' of
|
||||
the `encoder_hidden_states` are the IP-Adapter image embeddings.
|
||||
cross_attention_dim (`int`):
|
||||
The number of channels in the `encoder_hidden_states`.
|
||||
scale (`float`, defaults to 1.0):
|
||||
the weight scale of image prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, image_embedding_len, cross_attention_dim=None, scale=1.0):
|
||||
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.image_embedding_len = image_embedding_len
|
||||
self.scale = scale
|
||||
|
||||
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
@ -59,7 +85,18 @@ class IPAttnProcessor(nn.Module):
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
temb=None,
|
||||
ip_adapter_image_prompt_embeds=None,
|
||||
):
|
||||
if encoder_hidden_states is not None:
|
||||
# If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
|
||||
# we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
|
||||
assert ip_adapter_image_prompt_embeds is not None
|
||||
# The batch dimensions should match.
|
||||
assert ip_adapter_image_prompt_embeds.shape[0] == encoder_hidden_states.shape[0]
|
||||
# The channel dimensions should match.
|
||||
assert ip_adapter_image_prompt_embeds.shape[2] == encoder_hidden_states.shape[2]
|
||||
ip_hidden_states = ip_adapter_image_prompt_embeds
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
@ -86,12 +123,6 @@ class IPAttnProcessor(nn.Module):
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
# Split text encoder hidden states and image encoder hidden state.
|
||||
encoder_hidden_states, ip_hidden_states = (
|
||||
encoder_hidden_states[:, : -self.image_embedding_len, :],
|
||||
encoder_hidden_states[:, -self.image_embedding_len :, :],
|
||||
)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
@ -103,18 +134,18 @@ class IPAttnProcessor(nn.Module):
|
||||
hidden_states = torch.bmm(attention_probs, value)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# for ip-adapter
|
||||
ip_key = self.to_k_ip(ip_hidden_states)
|
||||
ip_value = self.to_v_ip(ip_hidden_states)
|
||||
if ip_hidden_states is not None:
|
||||
ip_key = self.to_k_ip(ip_hidden_states)
|
||||
ip_value = self.to_v_ip(ip_hidden_states)
|
||||
|
||||
ip_key = attn.head_to_batch_dim(ip_key)
|
||||
ip_value = attn.head_to_batch_dim(ip_value)
|
||||
ip_key = attn.head_to_batch_dim(ip_key)
|
||||
ip_value = attn.head_to_batch_dim(ip_value)
|
||||
|
||||
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
|
||||
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
|
||||
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
|
||||
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
|
||||
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
|
||||
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + self.scale * ip_hidden_states
|
||||
hidden_states = hidden_states + self.scale * ip_hidden_states
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
@ -138,16 +169,13 @@ class IPAttnProcessor2_0(torch.nn.Module):
|
||||
Args:
|
||||
hidden_size (`int`):
|
||||
The hidden size of the attention layer.
|
||||
image_embedding_len (`int`):
|
||||
The length of the IP-Adapter image embedding. It is assumed that the last `image_embedding_len` 'tokens' of
|
||||
the `encoder_hidden_states` are the IP-Adapter image embeddings.
|
||||
cross_attention_dim (`int`):
|
||||
The number of channels in the `encoder_hidden_states`.
|
||||
scale (`float`, defaults to 1.0):
|
||||
the weight scale of image prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, image_embedding_len, cross_attention_dim=None, scale=1.0):
|
||||
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0):
|
||||
super().__init__()
|
||||
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
@ -155,7 +183,6 @@ class IPAttnProcessor2_0(torch.nn.Module):
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.text_context_len = text_context_len
|
||||
self.scale = scale
|
||||
|
||||
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
@ -168,7 +195,18 @@ class IPAttnProcessor2_0(torch.nn.Module):
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
temb=None,
|
||||
ip_adapter_image_prompt_embeds=None,
|
||||
):
|
||||
if encoder_hidden_states is not None:
|
||||
# If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
|
||||
# we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
|
||||
assert ip_adapter_image_prompt_embeds is not None
|
||||
# The batch dimensions should match.
|
||||
assert ip_adapter_image_prompt_embeds.shape[0] == encoder_hidden_states.shape[0]
|
||||
# The channel dimensions should match.
|
||||
assert ip_adapter_image_prompt_embeds.shape[2] == encoder_hidden_states.shape[2]
|
||||
ip_hidden_states = ip_adapter_image_prompt_embeds
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
@ -200,12 +238,6 @@ class IPAttnProcessor2_0(torch.nn.Module):
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
# Split text encoder hidden states and image encoder hidden state.
|
||||
encoder_hidden_states, ip_hidden_states = (
|
||||
encoder_hidden_states[:, : -self.image_embedding_len, :],
|
||||
encoder_hidden_states[:, -self.image_embedding_len :, :],
|
||||
)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
@ -226,23 +258,23 @@ class IPAttnProcessor2_0(torch.nn.Module):
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# for ip-adapter
|
||||
ip_key = self.to_k_ip(ip_hidden_states)
|
||||
ip_value = self.to_v_ip(ip_hidden_states)
|
||||
if ip_hidden_states:
|
||||
ip_key = self.to_k_ip(ip_hidden_states)
|
||||
ip_value = self.to_v_ip(ip_hidden_states)
|
||||
|
||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
ip_hidden_states = F.scaled_dot_product_attention(
|
||||
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
ip_hidden_states = F.scaled_dot_product_attention(
|
||||
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||
|
||||
hidden_states = hidden_states + self.scale * ip_hidden_states
|
||||
hidden_states = hidden_states + self.scale * ip_hidden_states
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
|
@ -92,7 +92,6 @@ class IPAdapter:
|
||||
print("swapping in IPAttnProcessor for", name)
|
||||
attn_procs[name] = IPAttnProcessor(
|
||||
hidden_size=hidden_size,
|
||||
image_embedding_len=self.num_tokens,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
scale=1.0,
|
||||
).to(self.device, dtype=torch.float16)
|
||||
|
Reference in New Issue
Block a user