mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Specify the image_embedding_len in the IPAttnProcessor rather than the text embedding length. This enables the IPAttnProcessor to handle text embeddings of varying lengths.
This commit is contained in:
parent
7703bf2ca1
commit
c2d43f007b
@ -32,20 +32,21 @@ class IPAttnProcessor(nn.Module):
|
||||
Args:
|
||||
hidden_size (`int`):
|
||||
The hidden size of the attention layer.
|
||||
image_embedding_len (`int`):
|
||||
The length of the IP-Adapter image embedding. It is assumed that the last `image_embedding_len` 'tokens' of
|
||||
the `encoder_hidden_states` are the IP-Adapter image embeddings.
|
||||
cross_attention_dim (`int`):
|
||||
The number of channels in the `encoder_hidden_states`.
|
||||
text_context_len (`int`, defaults to 77):
|
||||
The context length of the text features.
|
||||
scale (`float`, defaults to 1.0):
|
||||
the weight scale of image prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, cross_attention_dim=None, text_context_len=77, scale=1.0):
|
||||
def __init__(self, hidden_size, image_embedding_len, cross_attention_dim=None, scale=1.0):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.text_context_len = text_context_len
|
||||
self.image_embedding_len = image_embedding_len
|
||||
self.scale = scale
|
||||
|
||||
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
@ -85,10 +86,10 @@ class IPAttnProcessor(nn.Module):
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
# split hidden states
|
||||
# Split text encoder hidden states and image encoder hidden state.
|
||||
encoder_hidden_states, ip_hidden_states = (
|
||||
encoder_hidden_states[:, : self.text_context_len, :],
|
||||
encoder_hidden_states[:, self.text_context_len :, :],
|
||||
encoder_hidden_states[:, : -self.image_embedding_len, :],
|
||||
encoder_hidden_states[:, -self.image_embedding_len :, :],
|
||||
)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
@ -137,15 +138,16 @@ class IPAttnProcessor2_0(torch.nn.Module):
|
||||
Args:
|
||||
hidden_size (`int`):
|
||||
The hidden size of the attention layer.
|
||||
image_embedding_len (`int`):
|
||||
The length of the IP-Adapter image embedding. It is assumed that the last `image_embedding_len` 'tokens' of
|
||||
the `encoder_hidden_states` are the IP-Adapter image embeddings.
|
||||
cross_attention_dim (`int`):
|
||||
The number of channels in the `encoder_hidden_states`.
|
||||
text_context_len (`int`, defaults to 77):
|
||||
The context length of the text features.
|
||||
scale (`float`, defaults to 1.0):
|
||||
the weight scale of image prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, cross_attention_dim=None, text_context_len=77, scale=1.0):
|
||||
def __init__(self, hidden_size, image_embedding_len, cross_attention_dim=None, scale=1.0):
|
||||
super().__init__()
|
||||
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
@ -198,10 +200,10 @@ class IPAttnProcessor2_0(torch.nn.Module):
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
# split hidden states
|
||||
# Split text encoder hidden states and image encoder hidden state.
|
||||
encoder_hidden_states, ip_hidden_states = (
|
||||
encoder_hidden_states[:, : self.text_context_len, :],
|
||||
encoder_hidden_states[:, self.text_context_len :, :],
|
||||
encoder_hidden_states[:, : -self.image_embedding_len, :],
|
||||
encoder_hidden_states[:, -self.image_embedding_len :, :],
|
||||
)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
|
@ -91,7 +91,10 @@ class IPAdapter:
|
||||
else:
|
||||
print("swapping in IPAttnProcessor for", name)
|
||||
attn_procs[name] = IPAttnProcessor(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
|
||||
hidden_size=hidden_size,
|
||||
image_embedding_len=self.num_tokens,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
scale=1.0,
|
||||
).to(self.device, dtype=torch.float16)
|
||||
unet.set_attn_processor(attn_procs)
|
||||
print("Modified UNet Attn Processors count:", len(unet.attn_processors))
|
||||
|
Loading…
Reference in New Issue
Block a user