Specify the image_embedding_len in the IPAttnProcessor rather than the text embedding length. This enables the IPAttnProcessor to handle text embeddings of varying lengths.
This commit is contained in: parent 7703bf2ca1, commit c2d43f007b
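For context, here is a minimal sketch (not part of the commit; tensor shapes and token counts are illustrative) of why slicing off the last `image_embedding_len` tokens is more robust than splitting at a fixed `text_context_len` of 77: the text portion of `encoder_hidden_states` can have any length, while the number of trailing IP-Adapter image tokens is fixed and known.

import torch

# Illustrative shapes (assumptions, not taken from the commit):
# a long prompt encoded to 154 text tokens plus 4 IP-Adapter image tokens.
batch, text_len, image_embedding_len, dim = 2, 154, 4, 768

text_embeds = torch.randn(batch, text_len, dim)
image_embeds = torch.randn(batch, image_embedding_len, dim)
encoder_hidden_states = torch.cat([text_embeds, image_embeds], dim=1)

# Old behavior: split at a fixed text_context_len of 77.
# With a 154-token prompt, image tokens leak into the "text" slice.
text_context_len = 77
old_text = encoder_hidden_states[:, :text_context_len, :]   # (2, 77, 768) -- truncated text
old_ip = encoder_hidden_states[:, text_context_len:, :]     # (2, 81, 768) -- text mixed with image tokens

# New behavior: split off the last image_embedding_len tokens,
# so text embeddings of any length are handled correctly.
new_text = encoder_hidden_states[:, :-image_embedding_len, :]  # (2, 154, 768)
new_ip = encoder_hidden_states[:, -image_embedding_len:, :]    # (2, 4, 768)

assert new_text.shape[1] == text_len and new_ip.shape[1] == image_embedding_len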
@@ -32,20 +32,21 @@ class IPAttnProcessor(nn.Module):
     Args:
         hidden_size (`int`):
             The hidden size of the attention layer.
+        image_embedding_len (`int`):
+            The length of the IP-Adapter image embedding. It is assumed that the last `image_embedding_len` 'tokens' of
+            the `encoder_hidden_states` are the IP-Adapter image embeddings.
         cross_attention_dim (`int`):
             The number of channels in the `encoder_hidden_states`.
-        text_context_len (`int`, defaults to 77):
-            The context length of the text features.
         scale (`float`, defaults to 1.0):
             the weight scale of image prompt.
     """

-    def __init__(self, hidden_size, cross_attention_dim=None, text_context_len=77, scale=1.0):
+    def __init__(self, hidden_size, image_embedding_len, cross_attention_dim=None, scale=1.0):
         super().__init__()

         self.hidden_size = hidden_size
         self.cross_attention_dim = cross_attention_dim
-        self.text_context_len = text_context_len
+        self.image_embedding_len = image_embedding_len
         self.scale = scale

         self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
@@ -85,10 +86,10 @@ class IPAttnProcessor(nn.Module):
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

-        # split hidden states
+        # Split text encoder hidden states and image encoder hidden state.
         encoder_hidden_states, ip_hidden_states = (
-            encoder_hidden_states[:, : self.text_context_len, :],
-            encoder_hidden_states[:, self.text_context_len :, :],
+            encoder_hidden_states[:, : -self.image_embedding_len, :],
+            encoder_hidden_states[:, -self.image_embedding_len :, :],
         )

         key = attn.to_k(encoder_hidden_states)
@@ -137,15 +138,16 @@ class IPAttnProcessor2_0(torch.nn.Module):
     Args:
         hidden_size (`int`):
             The hidden size of the attention layer.
+        image_embedding_len (`int`):
+            The length of the IP-Adapter image embedding. It is assumed that the last `image_embedding_len` 'tokens' of
+            the `encoder_hidden_states` are the IP-Adapter image embeddings.
         cross_attention_dim (`int`):
             The number of channels in the `encoder_hidden_states`.
-        text_context_len (`int`, defaults to 77):
-            The context length of the text features.
         scale (`float`, defaults to 1.0):
             the weight scale of image prompt.
     """

-    def __init__(self, hidden_size, cross_attention_dim=None, text_context_len=77, scale=1.0):
+    def __init__(self, hidden_size, image_embedding_len, cross_attention_dim=None, scale=1.0):
         super().__init__()

         if not hasattr(F, "scaled_dot_product_attention"):
@@ -198,10 +200,10 @@ class IPAttnProcessor2_0(torch.nn.Module):
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

-        # split hidden states
+        # Split text encoder hidden states and image encoder hidden state.
         encoder_hidden_states, ip_hidden_states = (
-            encoder_hidden_states[:, : self.text_context_len, :],
-            encoder_hidden_states[:, self.text_context_len :, :],
+            encoder_hidden_states[:, : -self.image_embedding_len, :],
+            encoder_hidden_states[:, -self.image_embedding_len :, :],
         )

         key = attn.to_k(encoder_hidden_states)
@@ -91,7 +91,10 @@ class IPAdapter:
             else:
                 print("swapping in IPAttnProcessor for", name)
                 attn_procs[name] = IPAttnProcessor(
-                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
+                    hidden_size=hidden_size,
+                    image_embedding_len=self.num_tokens,
+                    cross_attention_dim=cross_attention_dim,
+                    scale=1.0,
                 ).to(self.device, dtype=torch.float16)
         unet.set_attn_processor(attn_procs)
         print("Modified UNet Attn Processors count:", len(unet.attn_processors))
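As a usage sketch of the hunk above, the processors are built per cross-attention layer and installed with `unet.set_attn_processor`, passing the number of IP-Adapter image tokens as `image_embedding_len`. This follows the common IP-Adapter/diffusers wiring pattern; the `install_ip_attn_processors` helper, the layer-name parsing, and the default of 4 tokens are illustrative assumptions rather than part of this commit, and `IPAttnProcessor` is assumed importable from the module this diff modifies.

import torch
from diffusers.models.attention_processor import AttnProcessor

def install_ip_attn_processors(unet, device, num_tokens=4):
    # Hypothetical wiring, assuming a diffusers UNet2DConditionModel `unet`.
    attn_procs = {}
    for name in unet.attn_processors.keys():
        # Self-attention layers ("attn1") receive no encoder_hidden_states.
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        else:  # down_blocks
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]

        if cross_attention_dim is None:
            # Leave self-attention untouched.
            attn_procs[name] = AttnProcessor()
        else:
            attn_procs[name] = IPAttnProcessor(
                hidden_size=hidden_size,
                image_embedding_len=num_tokens,
                cross_attention_dim=cross_attention_dim,
                scale=1.0,
            ).to(device, dtype=torch.float16)
    unet.set_attn_processor(attn_procs)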