Specify the image_embedding_len in the IPAttnProcessor rather than the text embedding length. This enables the IPAttnProcessor to handle text embeddings of varying lengths.

Author: Ryan Dick
Date: 2023-09-07 18:20:21 -04:00
Parent: 7703bf2ca1
Commit: c2d43f007b
2 changed files with 19 additions and 14 deletions
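The key change is how the joint encoder_hidden_states tensor is split: slicing from the end with -image_embedding_len, instead of from the start with a fixed text_context_len=77, makes the split independent of the text length. A minimal standalone sketch of the idea (shapes and values are illustrative, not taken from this repo):

import torch

image_embedding_len = 4   # number of IP-Adapter image 'tokens' (illustrative)
text_len = 154            # e.g. a long or weighted prompt; no longer pinned to 77
encoder_hidden_states = torch.randn(2, text_len + image_embedding_len, 768)

# The last `image_embedding_len` tokens are the image embeddings;
# everything before them is text, whatever its length.
text_hidden_states = encoder_hidden_states[:, :-image_embedding_len, :]
ip_hidden_states = encoder_hidden_states[:, -image_embedding_len:, :]

assert text_hidden_states.shape[1] == text_len
assert ip_hidden_states.shape[1] == image_embedding_len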

Changed file 1 of 2

@@ -32,20 +32,21 @@ class IPAttnProcessor(nn.Module):
     Args:
         hidden_size (`int`):
             The hidden size of the attention layer.
+        image_embedding_len (`int`):
+            The length of the IP-Adapter image embedding. It is assumed that the last `image_embedding_len` 'tokens' of
+            the `encoder_hidden_states` are the IP-Adapter image embeddings.
         cross_attention_dim (`int`):
             The number of channels in the `encoder_hidden_states`.
-        text_context_len (`int`, defaults to 77):
-            The context length of the text features.
         scale (`float`, defaults to 1.0):
             the weight scale of image prompt.
     """
 
-    def __init__(self, hidden_size, cross_attention_dim=None, text_context_len=77, scale=1.0):
+    def __init__(self, hidden_size, image_embedding_len, cross_attention_dim=None, scale=1.0):
         super().__init__()
 
         self.hidden_size = hidden_size
         self.cross_attention_dim = cross_attention_dim
-        self.text_context_len = text_context_len
+        self.image_embedding_len = image_embedding_len
         self.scale = scale
 
         self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
@@ -85,10 +86,10 @@ class IPAttnProcessor(nn.Module):
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
-        # split hidden states
+        # Split text encoder hidden states and image encoder hidden state.
         encoder_hidden_states, ip_hidden_states = (
-            encoder_hidden_states[:, : self.text_context_len, :],
-            encoder_hidden_states[:, self.text_context_len :, :],
+            encoder_hidden_states[:, : -self.image_embedding_len, :],
+            encoder_hidden_states[:, -self.image_embedding_len :, :],
         )
 
         key = attn.to_k(encoder_hidden_states)
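For context, the two halves feed the IP-Adapter's decoupled cross-attention: the text half goes through the attention layer's own to_k/to_v, while the image half goes through the to_k_ip/to_v_ip projections declared in the constructor above, and the two attention outputs are summed with the scale weight. A hedged, simplified sketch of that combination (single-head, no masking; not this repo's exact code):

import math
import torch
import torch.nn.functional as F

def decoupled_cross_attention(query, key, value, ip_key, ip_value, scale=1.0):
    d = query.shape[-1]
    # Text branch: standard scaled dot-product cross-attention.
    text_out = F.softmax(query @ key.transpose(-1, -2) / math.sqrt(d), dim=-1) @ value
    # Image branch: same attention, but over the IP-specific key/value projections.
    ip_out = F.softmax(query @ ip_key.transpose(-1, -2) / math.sqrt(d), dim=-1) @ ip_value
    # Decoupled cross-attention: add the image contribution, weighted by `scale`.
    return text_out + scale * ip_out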
@@ -137,15 +138,16 @@ class IPAttnProcessor2_0(torch.nn.Module):
     Args:
         hidden_size (`int`):
             The hidden size of the attention layer.
+        image_embedding_len (`int`):
+            The length of the IP-Adapter image embedding. It is assumed that the last `image_embedding_len` 'tokens' of
+            the `encoder_hidden_states` are the IP-Adapter image embeddings.
         cross_attention_dim (`int`):
             The number of channels in the `encoder_hidden_states`.
-        text_context_len (`int`, defaults to 77):
-            The context length of the text features.
         scale (`float`, defaults to 1.0):
             the weight scale of image prompt.
     """
 
-    def __init__(self, hidden_size, cross_attention_dim=None, text_context_len=77, scale=1.0):
+    def __init__(self, hidden_size, image_embedding_len, cross_attention_dim=None, scale=1.0):
         super().__init__()
 
         if not hasattr(F, "scaled_dot_product_attention"):
@@ -198,10 +200,10 @@ class IPAttnProcessor2_0(torch.nn.Module):
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
 
-        # split hidden states
+        # Split text encoder hidden states and image encoder hidden state.
         encoder_hidden_states, ip_hidden_states = (
-            encoder_hidden_states[:, : self.text_context_len, :],
-            encoder_hidden_states[:, self.text_context_len :, :],
+            encoder_hidden_states[:, : -self.image_embedding_len, :],
+            encoder_hidden_states[:, -self.image_embedding_len :, :],
        )
 
         key = attn.to_k(encoder_hidden_states)

Changed file 2 of 2

@@ -91,7 +91,10 @@ class IPAdapter:
             else:
                 print("swapping in IPAttnProcessor for", name)
                 attn_procs[name] = IPAttnProcessor(
-                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
+                    hidden_size=hidden_size,
+                    image_embedding_len=self.num_tokens,
+                    cross_attention_dim=cross_attention_dim,
+                    scale=1.0,
                 ).to(self.device, dtype=torch.float16)
         unet.set_attn_processor(attn_procs)
         print("Modified UNet Attn Processors count:", len(unet.attn_processors))