mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add support for multi-image IP-Adapter.
This commit is contained in:
@ -67,6 +67,12 @@ class IPAttnProcessor2_0(torch.nn.Module):
|
||||
temb=None,
|
||||
ip_adapter_image_prompt_embeds=None,
|
||||
):
|
||||
"""Apply IP-Adapter attention.
|
||||
|
||||
Args:
|
||||
ip_adapter_image_prompt_embeds (torch.Tensor): The image prompt embeddings.
|
||||
Shape: (batch_size, num_ip_images, seq_len, ip_embedding_len).
|
||||
"""
|
||||
residual = hidden_states
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
@ -127,26 +133,35 @@ class IPAttnProcessor2_0(torch.nn.Module):
|
||||
for ipa_embed, ipa_weights, scale in zip(ip_adapter_image_prompt_embeds, self._weights, self._scales):
|
||||
# The batch dimensions should match.
|
||||
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
|
||||
# The channel dimensions should match.
|
||||
assert ipa_embed.shape[2] == encoder_hidden_states.shape[2]
|
||||
# The token_len dimensions should match.
|
||||
assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1]
|
||||
|
||||
ip_hidden_states = ipa_embed
|
||||
|
||||
# ip_hidden_state.shape = (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
|
||||
|
||||
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
|
||||
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
|
||||
|
||||
# ip_key.shape, ip_value.shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
|
||||
|
||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# The output of sdpa has shape: (batch, num_heads, seq_len, head_dim)
|
||||
# ip_key.shape, ip_value.shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
|
||||
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
ip_hidden_states = F.scaled_dot_product_attention(
|
||||
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
# ip_hidden_states.shape: (batch_size, num_heads, query_seq_len, head_dim)
|
||||
|
||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||
|
||||
# ip_hidden_states.shape: (batch_size, query_seq_len, num_heads * head_dim)
|
||||
|
||||
hidden_states = hidden_states + scale * ip_hidden_states
|
||||
|
||||
# linear proj
|
||||
|
@ -55,11 +55,11 @@ class PostprocessingSettings:
|
||||
class IPAdapterConditioningInfo:
|
||||
cond_image_prompt_embeds: torch.Tensor
|
||||
"""IP-Adapter image encoder conditioning embeddings.
|
||||
Shape: (batch_size, num_tokens, encoding_dim).
|
||||
Shape: (num_images, num_tokens, encoding_dim).
|
||||
"""
|
||||
uncond_image_prompt_embeds: torch.Tensor
|
||||
"""IP-Adapter image encoding embeddings to use for unconditional generation.
|
||||
Shape: (batch_size, num_tokens, encoding_dim).
|
||||
Shape: (num_images, num_tokens, encoding_dim).
|
||||
"""
|
||||
|
||||
|
||||
|
@ -345,9 +345,12 @@ class InvokeAIDiffuserComponent:
|
||||
|
||||
cross_attention_kwargs = None
|
||||
if conditioning_data.ip_adapter_conditioning is not None:
|
||||
# Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
|
||||
cross_attention_kwargs = {
|
||||
"ip_adapter_image_prompt_embeds": [
|
||||
torch.cat([ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds])
|
||||
torch.stack(
|
||||
[ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds]
|
||||
)
|
||||
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
|
||||
]
|
||||
}
|
||||
@ -415,9 +418,10 @@ class InvokeAIDiffuserComponent:
|
||||
# Run unconditional UNet denoising.
|
||||
cross_attention_kwargs = None
|
||||
if conditioning_data.ip_adapter_conditioning is not None:
|
||||
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
|
||||
cross_attention_kwargs = {
|
||||
"ip_adapter_image_prompt_embeds": [
|
||||
ipa_conditioning.uncond_image_prompt_embeds
|
||||
torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
|
||||
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
|
||||
]
|
||||
}
|
||||
@ -444,9 +448,10 @@ class InvokeAIDiffuserComponent:
|
||||
# Run conditional UNet denoising.
|
||||
cross_attention_kwargs = None
|
||||
if conditioning_data.ip_adapter_conditioning is not None:
|
||||
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
|
||||
cross_attention_kwargs = {
|
||||
"ip_adapter_image_prompt_embeds": [
|
||||
ipa_conditioning.cond_image_prompt_embeds
|
||||
torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
|
||||
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
|
||||
]
|
||||
}
|
||||
|
Reference in New Issue
Block a user