Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Update CustomAttention to support both IP-Adapters and regional prompting.
parent 2ce79b61f5
commit 31c456c1e6
@@ -5,9 +5,44 @@ import torch.nn.functional as F
 from diffusers.models.attention_processor import Attention, AttnProcessor2_0
 from diffusers.utils import USE_PEFT_BACKEND
 
+from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
+from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
+
 
 class CustomAttnProcessor2_0(AttnProcessor2_0):
-    """An attention processor that supports regional prompt attention for PyTorch 2.0."""
+    """A custom implementation of AttnProcessor2_0 that supports additional Invoke features.
+    This implementation is based on
+    https://github.com/huggingface/diffusers/blame/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1204
+    Supported custom features:
+    - IP-Adapter
+    - Regional prompt attention
+    """
+
+    def __init__(
+        self,
+        ip_adapter_weights: Optional[list[IPAttentionProcessorWeights]] = None,
+        ip_adapter_scales: Optional[list[float]] = None,
+    ):
+        """Initialize a CustomAttnProcessor2_0.
+        Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are
+        layer-specific are passed to __init__().
+        Args:
+            ip_adapter_weights: The IP-Adapter attention weights. ip_adapter_weights[i] contains the attention weights
+                for the i'th IP-Adapter.
+            ip_adapter_scales: The IP-Adapter attention scales. ip_adapter_scales[i] contains the attention scale for
+                the i'th IP-Adapter.
+        """
+        super().__init__()
+
+        self._ip_adapter_weights = ip_adapter_weights
+        self._ip_adapter_scales = ip_adapter_scales
+
+        assert (self._ip_adapter_weights is None) == (self._ip_adapter_scales is None)
+        if self._ip_adapter_weights is not None:
+            assert len(ip_adapter_weights) == len(ip_adapter_scales)
+
+    def _is_ip_adapter_enabled(self) -> bool:
+        return self._ip_adapter_weights is not None
 
     def __call__(
         self,
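The new __init__() binds the layer-specific state (per-layer IP-Adapter weights and scales), while everything that is shared across layers arrives later through __call__(). Below is a minimal sketch of what installing the processor on a diffusers UNet could look like. It is not InvokeAI's actual wiring: the tiny UNet config, the import path for CustomAttnProcessor2_0, and the use of None weights (IP-Adapter disabled) are illustrative assumptions.

from diffusers import UNet2DConditionModel

# Assumed import path for the module changed in this diff.
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0

# Deliberately tiny UNet so the sketch runs without downloading a full model.
unet = UNet2DConditionModel(
    sample_size=32,
    block_out_channels=(32, 64),
    layers_per_block=1,
    cross_attention_dim=32,
    attention_head_dim=4,
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
)

# Layer-specific arguments are bound at construction time, one processor per attention layer.
# With ip_adapter_weights=None the processor falls back to stock AttnProcessor2_0 behaviour.
custom_procs = {
    name: CustomAttnProcessor2_0(ip_adapter_weights=None, ip_adapter_scales=None)
    for name in unet.attn_processors
}
unet.set_attn_processor(custom_procs)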
@@ -17,7 +52,25 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         attention_mask: Optional[torch.FloatTensor] = None,
         temb: Optional[torch.FloatTensor] = None,
         scale: float = 1.0,
+        # For regional prompting:
+        regional_prompt_data: Optional[RegionalPromptData] = None,
+        percent_through: Optional[torch.FloatTensor] = None,
+        # For IP-Adapter:
+        ip_adapter_image_prompt_embeds: Optional[list[torch.Tensor]] = None,
     ) -> torch.FloatTensor:
+        """Apply attention.
+        Args:
+            regional_prompt_data: The regional prompt data for the current batch. If not None, this will be used to
+                apply regional prompt masking.
+            ip_adapter_image_prompt_embeds: The IP-Adapter image prompt embeddings for the current batch.
+                ip_adapter_image_prompt_embeds[i] contains the image prompt embeddings for the i'th IP-Adapter. Each
+                tensor has shape (batch_size, num_ip_images, seq_len, ip_embedding_len).
+        """
+        # If true, we are doing cross-attention, if false we are doing self-attention.
+        is_cross_attention = encoder_hidden_states is not None
+
+        # Start unmodified block from AttnProcessor2_0.
+        # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
         residual = hidden_states
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
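The parameters added to __call__() are the per-step inputs. In diffusers, extra entries in cross_attention_kwargs on the UNet forward pass are forwarded as keyword arguments to every attention processor, which is presumably how these values reach CustomAttnProcessor2_0. Continuing the sketch above (same tiny unet), a single denoising step with both features disabled might look like this; tensor sizes are illustrative and tied to that toy config.

import torch

latents = torch.randn(1, 4, 32, 32)     # matches the tiny UNet's in_channels / sample_size
prompt_embeds = torch.randn(1, 77, 32)  # matches cross_attention_dim=32

with torch.no_grad():
    noise_pred = unet(
        sample=latents,
        timestep=torch.tensor(10),
        encoder_hidden_states=prompt_embeds,
        # Per-step arguments: forwarded to every CustomAttnProcessor2_0.__call__().
        cross_attention_kwargs={
            "regional_prompt_data": None,            # a RegionalPromptData instance when regions are active
            "percent_through": None,                 # e.g. torch.tensor(step_index / total_steps)
            "ip_adapter_image_prompt_embeds": None,  # list of per-adapter image prompt embeddings
        },
    ).sample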
@@ -31,7 +84,25 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         batch_size, sequence_length, _ = (
             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
         )
+        # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+        # End unmodified block from AttnProcessor2_0.
 
+        # Handle regional prompt attention masks.
+        if regional_prompt_data is not None:
+            assert percent_through is not None
+            _, query_seq_len, _ = hidden_states.shape
+            if is_cross_attention:
+                prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
+                    query_seq_len=query_seq_len, key_seq_len=sequence_length
+                )
+
+                if attention_mask is None:
+                    attention_mask = prompt_region_attention_mask
+                else:
+                    attention_mask = prompt_region_attention_mask + attention_mask
+
+        # Start unmodified block from AttnProcessor2_0.
+        # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
         if attention_mask is not None:
             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
             # scaled_dot_product_attention expects attention_mask shape to be
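When a regional mask and an existing attention mask are both present, the hunk above simply sums them. That composition works because a floating-point attn_mask in F.scaled_dot_product_attention is an additive bias on the pre-softmax scores, so two bias masks combine by addition. A standalone illustration of that semantics (not InvokeAI code; the shapes and the -10000.0 fill value are arbitrary):

import torch
import torch.nn.functional as F

# A float attn_mask is added to the raw attention scores before softmax,
# so two bias masks compose by simple addition.
q = torch.randn(1, 2, 4, 8)  # (batch, heads, query_seq_len, head_dim)
k = torch.randn(1, 2, 6, 8)  # (batch, heads, key_seq_len, head_dim)
v = torch.randn(1, 2, 6, 8)

base_mask = torch.zeros(1, 1, 4, 6)
base_mask[..., 4:] = -10000.0  # e.g. hide padding keys for every query

region_mask = torch.zeros(1, 1, 4, 6)
region_mask[:, :, :2, 2:] = -10000.0  # queries 0-1 may only attend to keys 0-1

out = F.scaled_dot_product_attention(q, k, v, attn_mask=base_mask + region_mask)
print(out.shape)  # torch.Size([1, 2, 4, 8])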
@@ -68,7 +139,54 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
+        # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+        # End unmodified block from AttnProcessor2_0.
 
+        # Apply IP-Adapter conditioning.
+        if is_cross_attention:
+            if self._is_ip_adapter_enabled():
+                assert ip_adapter_image_prompt_embeds is not None
+                for ipa_embed, ipa_weights, scale in zip(
+                    ip_adapter_image_prompt_embeds, self._ip_adapter_weights, self._ip_adapter_scales, strict=True
+                ):
+                    # The batch dimensions should match.
+                    assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
+                    # The token_len dimensions should match.
+                    assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1]
+
+                    ip_hidden_states = ipa_embed
+
+                    # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
+
+                    ip_key = ipa_weights.to_k_ip(ip_hidden_states)
+                    ip_value = ipa_weights.to_v_ip(ip_hidden_states)
+
+                    # Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
+
+                    ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                    ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+                    # Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
+
+                    # TODO: add support for attn.scale when we move to Torch 2.1
+                    ip_hidden_states = F.scaled_dot_product_attention(
+                        query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                    )
+
+                    # Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
+
+                    ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+                    ip_hidden_states = ip_hidden_states.to(query.dtype)
+
+                    # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
+
+                    hidden_states = hidden_states + scale * ip_hidden_states
+            else:
+                # If IP-Adapter is not enabled, then ip_adapter_image_prompt_embeds should not be passed in.
+                assert ip_adapter_image_prompt_embeds is None
+
+        # Start unmodified block from AttnProcessor2_0.
+        # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
         # linear proj
         hidden_states = attn.to_out[0](hidden_states, *args)
         # dropout
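For reference, the IP-Adapter branch above is a second, decoupled attention pass: the image prompt embeddings are projected through the adapter's to_k_ip/to_v_ip weights, attended against the same query as the text cross-attention, and the result is added to hidden_states scaled by the per-adapter scale. The shape-flow sketch below uses plain nn.Linear layers as stand-ins for ipa_weights.to_k_ip / to_v_ip; all dimensions are illustrative, not taken from a real model.

import torch
import torch.nn.functional as F

batch_size, heads, head_dim = 2, 8, 40
query_seq_len, num_ip_images, ip_seq_len, ip_embedding_len = 64, 1, 4, 1024
inner_dim = heads * head_dim

# Stand-ins for ipa_weights.to_k_ip / ipa_weights.to_v_ip (illustrative sizes).
to_k_ip = torch.nn.Linear(ip_embedding_len, inner_dim, bias=False)
to_v_ip = torch.nn.Linear(ip_embedding_len, inner_dim, bias=False)

query = torch.randn(batch_size, heads, query_seq_len, head_dim)
hidden_states = torch.randn(batch_size, query_seq_len, inner_dim)
ipa_embed = torch.randn(batch_size, num_ip_images, ip_seq_len, ip_embedding_len)

ip_key = to_k_ip(ipa_embed).view(batch_size, -1, heads, head_dim).transpose(1, 2)
ip_value = to_v_ip(ipa_embed).view(batch_size, -1, heads, head_dim).transpose(1, 2)
# ip_key / ip_value: (batch_size, heads, num_ip_images * ip_seq_len, head_dim)

ip_hidden_states = F.scaled_dot_product_attention(query, ip_key, ip_value)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, inner_dim)
# ip_hidden_states: (batch_size, query_seq_len, heads * head_dim)

scale = 1.0  # per-adapter scale
hidden_states = hidden_states + scale * ip_hidden_states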