Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Get multi-prompt attention working simultaneously with IP-adapter.
@@ -1,185 +0,0 @@
-# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
-# and modified as needed
-
-# tencent-ailab comment:
-# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from diffusers.models.attention_processor import AttnProcessor2_0 as DiffusersAttnProcessor2_0
-
-from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
-
-
-# Create a version of AttnProcessor2_0 that is a sub-class of nn.Module. This is required for IP-Adapter state_dict
-# loading.
-class AttnProcessor2_0(DiffusersAttnProcessor2_0, nn.Module):
-    def __init__(self):
-        DiffusersAttnProcessor2_0.__init__(self)
-        nn.Module.__init__(self)
-
-    def __call__(
-        self,
-        attn,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        temb=None,
-        ip_adapter_image_prompt_embeds=None,
-    ):
-        """Re-definition of DiffusersAttnProcessor2_0.__call__(...) that accepts and ignores the
-        ip_adapter_image_prompt_embeds parameter.
-        """
-        return DiffusersAttnProcessor2_0.__call__(
-            self, attn, hidden_states, encoder_hidden_states, attention_mask, temb
-        )
-
-
-class IPAttnProcessor2_0(torch.nn.Module):
-    r"""
-    Attention processor for IP-Adapater for PyTorch 2.0.
-    Args:
-        hidden_size (`int`):
-            The hidden size of the attention layer.
-        cross_attention_dim (`int`):
-            The number of channels in the `encoder_hidden_states`.
-        scale (`float`, defaults to 1.0):
-            the weight scale of image prompt.
-    """
-
-    def __init__(self, weights: list[IPAttentionProcessorWeights], scales: list[float]):
-        super().__init__()
-
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
-        assert len(weights) == len(scales)
-
-        self._weights = weights
-        self._scales = scales
-
-    def __call__(
-        self,
-        attn,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        temb=None,
-        ip_adapter_image_prompt_embeds=None,
-    ):
-        """Apply IP-Adapter attention.
-
-        Args:
-            ip_adapter_image_prompt_embeds (torch.Tensor): The image prompt embeddings.
-                Shape: (batch_size, num_ip_images, seq_len, ip_embedding_len).
-        """
-        # If true, we are doing cross-attention, if false we are doing self-attention.
-        is_cross_attention = encoder_hidden_states is not None
-
-        residual = hidden_states
-
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
-        if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-            # scaled_dot_product_attention expects attention_mask shape to be
-            # (batch, heads, source_length, target_length)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
-        if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-        query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.norm_cross:
-            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if is_cross_attention:
-            # If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
-            # we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
-            assert ip_adapter_image_prompt_embeds is not None
-            assert len(ip_adapter_image_prompt_embeds) == len(self._weights)
-
-            for ipa_embed, ipa_weights, scale in zip(
-                ip_adapter_image_prompt_embeds, self._weights, self._scales, strict=True
-            ):
-                # The batch dimensions should match.
-                assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
-                # The token_len dimensions should match.
-                assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1]
-
-                ip_hidden_states = ipa_embed
-
-                # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
-
-                ip_key = ipa_weights.to_k_ip(ip_hidden_states)
-                ip_value = ipa_weights.to_v_ip(ip_hidden_states)
-
-                # Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
-
-                ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-                ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-                # Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
-
-                # TODO: add support for attn.scale when we move to Torch 2.1
-                ip_hidden_states = F.scaled_dot_product_attention(
-                    query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-                )
-
-                # Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
-
-                ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-                ip_hidden_states = ip_hidden_states.to(query.dtype)
-
-                # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
-
-                hidden_states = hidden_states + scale * ip_hidden_states
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        return hidden_states
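For orientation, the deleted `IPAttnProcessor2_0` above implements the IP-Adapter "decoupled cross-attention" pattern: a standard text cross-attention pass plus a second attention pass over the image-prompt keys/values, blended in with a per-adapter scale. A minimal, self-contained sketch of that computation with made-up tensor shapes (plain PyTorch, not InvokeAI's classes or real layer weights):

import torch
import torch.nn.functional as F

# Made-up shapes: batch of 2, 8 heads, 64-dim heads, 4096 query tokens (a 64x64 latent),
# 77 text tokens, 4 image-prompt tokens.
batch, heads, head_dim = 2, 8, 64
query = torch.randn(batch, heads, 4096, head_dim)
text_key = torch.randn(batch, heads, 77, head_dim)
text_value = torch.randn(batch, heads, 77, head_dim)
ip_key = torch.randn(batch, heads, 4, head_dim)
ip_value = torch.randn(batch, heads, 4, head_dim)
ip_scale = 1.0

# Standard text cross-attention.
text_out = F.scaled_dot_product_attention(query, text_key, text_value)
# Second, "decoupled" attention pass over the image-prompt keys/values, summed in with a scale.
ip_out = F.scaled_dot_product_attention(query, ip_key, ip_value)
hidden_states = text_out + ip_scale * ip_out
print(hidden_states.shape)  # torch.Size([2, 8, 4096, 64])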
@@ -23,13 +23,12 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
-from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     IPAdapterConditioningInfo,
     TextConditioningData,
 )
-from invokeai.backend.stable_diffusion.diffusion.regional_prompt_attention import apply_regional_prompt_attn
 from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
+from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher

 from ..util import auto_detect_slice_size, normalize_device

@@ -427,11 +426,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             raise ValueError(
                 "Prompt-to-prompt cross-attention control (`.swap()`) and regional prompting cannot be used simultaneously."
             )
-        if use_ip_adapter and use_regional_prompting:
-            # TODO(ryand): Implement this.
-            raise NotImplementedError("Coming soon.")

-        ip_adapter_unet_patcher = None
+        unet_attention_patcher = None
         self.use_ip_adapter = use_ip_adapter
         attn_ctx = nullcontext()
         if use_cross_attention_control:

@@ -439,11 +435,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 self.invokeai_diffuser.model,
                 extra_conditioning_info=extra_conditioning_info,
             )
-        if use_ip_adapter:
-            ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
-            attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
-        if use_regional_prompting:
-            attn_ctx = apply_regional_prompt_attn(self.invokeai_diffuser.model)
+        if use_ip_adapter or use_regional_prompting:
+            ip_adapters = [ipa.ip_adapter_model for ipa in ip_adapter_data] if use_ip_adapter else None
+            unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
+            attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)

         with attn_ctx:
             if callback is not None:

@@ -471,7 +466,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                     control_data=control_data,
                     ip_adapter_data=ip_adapter_data,
                     t2i_adapter_data=t2i_adapter_data,
-                    ip_adapter_unet_patcher=ip_adapter_unet_patcher,
+                    unet_attention_patcher=unet_attention_patcher,
                 )
                 latents = step_output.prev_sample
                 predicted_original = getattr(step_output, "pred_original_sample", None)

@@ -503,7 +498,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         control_data: List[ControlNetData] = None,
         ip_adapter_data: Optional[list[IPAdapterData]] = None,
         t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
-        ip_adapter_unet_patcher: Optional[UNetPatcher] = None,
+        unet_attention_patcher: Optional[UNetAttentionPatcher] = None,
     ):
         # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
         timestep = t[0]

@@ -526,10 +521,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 )
                 if step_index >= first_adapter_step and step_index <= last_adapter_step:
                     # Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
-                    ip_adapter_unet_patcher.set_scale(i, weight)
+                    unet_attention_patcher.set_scale(i, weight)
                 else:
                     # Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
-                    ip_adapter_unet_patcher.set_scale(i, 0.0)
+                    unet_attention_patcher.set_scale(i, 0.0)

         # Handle ControlNet(s)
         down_block_additional_residuals = None
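The last hunk above keeps each IP-Adapter's begin/end step gating, now routed through `unet_attention_patcher.set_scale(...)`. A small sketch of the gating rule it applies (a hypothetical helper for illustration, not a function in the codebase):

# Mirrors the set_scale(i, weight) / set_scale(i, 0.0) branch above: inside the configured
# step range the adapter keeps its weight, outside it contributes nothing.
def effective_ip_adapter_scale(step_index: int, first_adapter_step: int, last_adapter_step: int, weight: float) -> float:
    if first_adapter_step <= step_index <= last_adapter_step:
        return weight
    return 0.0

assert effective_ip_adapter_scale(5, 0, 10, 0.8) == 0.8
assert effective_ip_adapter_scale(11, 0, 10, 0.8) == 0.0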
@@ -6,7 +6,7 @@ from diffusers.models.attention_processor import Attention, AttnProcessor2_0
 from diffusers.utils import USE_PEFT_BACKEND

 from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
-from invokeai.backend.stable_diffusion.diffusion.regional_prompt_attention import RegionalPromptData
+from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData


 class CustomAttnProcessor2_0(AttnProcessor2_0):

@@ -149,10 +149,9 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         # End unmodified block from AttnProcessor2_0.

         # Apply IP-Adapter conditioning.
-        if is_cross_attention:
+        if is_cross_attention and self._is_ip_adapter_enabled():
             if self._is_ip_adapter_enabled():
                 assert ip_adapter_image_prompt_embeds is not None

                 for ipa_embed, ipa_weights, scale in zip(
                     ip_adapter_image_prompt_embeds, self._ip_adapter_weights, self._ip_adapter_scales, strict=True
                 ):
@@ -1,208 +0,0 @@
-from contextlib import contextmanager
-from typing import Optional
-
-import torch
-import torch.nn.functional as F
-from diffusers import UNet2DConditionModel
-from diffusers.models.attention_processor import Attention, AttnProcessor2_0
-from diffusers.utils import USE_PEFT_BACKEND
-
-from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
-    TextConditioningRegions,
-)
-
-
-class RegionalPromptData:
-    def __init__(self, attn_masks_by_seq_len: dict[int, torch.Tensor]):
-        self._attn_masks_by_seq_len = attn_masks_by_seq_len
-
-    @classmethod
-    def from_regions(
-        cls,
-        regions: list[TextConditioningRegions],
-        key_seq_len: int,
-        # TODO(ryand): Pass in a list of downscale factors?
-        max_downscale_factor: int = 8,
-    ):
-        """Construct a `RegionalPromptData` object.
-
-        Args:
-            regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the
-                batch.
-            key_seq_len (int): The sequence length of the expected prompt embeddings (which act as the key in the
-                cross-attention layers). This is most likely equal to the max embedding range end, but we pass it
-                explicitly to be sure.
-        """
-        attn_masks_by_seq_len = {}
-
-        # batch_attn_mask_by_seq_len[b][s] contains the attention mask for the b'th batch sample with a query sequence
-        # length of s.
-        batch_attn_masks_by_seq_len: list[dict[int, torch.Tensor]] = []
-        for batch_sample_regions in regions:
-            batch_attn_masks_by_seq_len.append({})
-
-            # Convert the bool masks to float masks so that max pooling can be applied.
-            batch_masks = batch_sample_regions.masks.to(dtype=torch.float32)
-
-            # Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
-            downscale_factor = 1
-            while downscale_factor <= max_downscale_factor:
-                _, num_prompts, h, w = batch_masks.shape
-                query_seq_len = h * w
-
-                # Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1).
-                batch_query_masks = batch_masks.reshape((1, num_prompts, -1, 1))
-
-                # Create a cross-attention mask for each prompt that selects the corresponding embeddings from
-                # `encoder_hidden_states`.
-                # attn_mask shape: (batch_size, query_seq_len, key_seq_len)
-                # TODO(ryand): What device / dtype should this be?
-                attn_mask = torch.zeros((1, query_seq_len, key_seq_len))
-
-                for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
-                    attn_mask[0, :, embedding_range.start : embedding_range.end] = batch_query_masks[
-                        :, prompt_idx, :, :
-                    ]
-
-                batch_attn_masks_by_seq_len[-1][query_seq_len] = attn_mask
-
-                downscale_factor *= 2
-                if downscale_factor <= max_downscale_factor:
-                    # We use max pooling because we downscale to a pretty low resolution, so we don't want small prompt
-                    # regions to be lost entirely.
-                    # TODO(ryand): In the future, we may want to experiment with other downsampling methods, and could
-                    # potentially use a weighted mask rather than a binary mask.
-                    batch_masks = F.max_pool2d(batch_masks, kernel_size=2, stride=2)
-
-        # Merge the batch_attn_masks_by_seq_len into a single attn_masks_by_seq_len.
-        for query_seq_len in batch_attn_masks_by_seq_len[0].keys():
-            attn_masks_by_seq_len[query_seq_len] = torch.cat(
-                [batch_attn_masks_by_seq_len[i][query_seq_len] for i in range(len(batch_attn_masks_by_seq_len))]
-            )
-
-        return cls(attn_masks_by_seq_len)
-
-    def get_attn_mask(self, query_seq_len: int) -> torch.Tensor:
-        """Get the attention mask for the given query sequence length (i.e. downscaling level).
-
-        This is called during cross-attention, where query_seq_len is the length of the flattened spatial features, so
-        it changes at each downscaling level in the model.
-
-        key_seq_len is the length of the expected prompt embeddings.
-
-        Returns:
-            torch.Tensor: The masks.
-                shape: (batch_size, query_seq_len, key_seq_len).
-                dtype: float
-                The mask is a binary mask with values of 0.0 and 1.0.
-        """
-        return self._attn_masks_by_seq_len[query_seq_len]
-
-
-class RegionalPromptAttnProcessor2_0(AttnProcessor2_0):
-    """An attention processor that supports regional prompt attention for PyTorch 2.0."""
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        temb: Optional[torch.FloatTensor] = None,
-        scale: float = 1.0,
-        regional_prompt_data: Optional[RegionalPromptData] = None,
-    ) -> torch.FloatTensor:
-        residual = hidden_states
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
-        if encoder_hidden_states is not None and regional_prompt_data is not None:
-            # If encoder_hidden_states is not None, that means we are doing cross-attention case.
-            _, query_seq_len, _ = hidden_states.shape
-            prompt_region_attention_mask = regional_prompt_data.get_attn_mask(query_seq_len)
-            # TODO(ryand): Avoid redundant type/device conversion here.
-            prompt_region_attention_mask = prompt_region_attention_mask.to(
-                dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device
-            )
-            prompt_region_attention_mask[prompt_region_attention_mask < 0.5] = -10000.0
-            prompt_region_attention_mask[prompt_region_attention_mask >= 0.5] = 0.0
-
-            if attention_mask is None:
-                attention_mask = prompt_region_attention_mask
-            else:
-                attention_mask = prompt_region_attention_mask + attention_mask
-
-        if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-            # scaled_dot_product_attention expects attention_mask shape to be
-            # (batch, heads, source_length, target_length)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
-        if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-        args = () if USE_PEFT_BACKEND else (scale,)
-        query = attn.to_q(hidden_states, *args)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.norm_cross:
-            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-        key = attn.to_k(encoder_hidden_states, *args)
-        value = attn.to_v(encoder_hidden_states, *args)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states, *args)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        return hidden_states
-
-
-@contextmanager
-def apply_regional_prompt_attn(unet: UNet2DConditionModel):
-    """A context manager that patches `unet` with RegionalPromptAttnProcessor2_0 attention processors."""
-
-    orig_attn_processors = unet.attn_processors
-
-    try:
-        unet.set_attn_processor(RegionalPromptAttnProcessor2_0())
-        yield None
-    finally:
-        unet.set_attn_processor(orig_attn_processors)
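The deleted `RegionalPromptAttnProcessor2_0` above turns each binary region mask into an additive attention bias (0.0 where a query position may attend to a prompt's tokens, -10000.0 elsewhere) before handing it to `F.scaled_dot_product_attention`. A standalone sketch of that conversion with toy shapes and plain tensors (no InvokeAI types):

import torch
import torch.nn.functional as F

# Toy shapes: 1 sample, 16 query positions (a 4x4 latent), 6 prompt tokens split across two regions.
region_mask = torch.zeros(1, 16, 6)
region_mask[:, :8, 0:3] = 1.0  # the top half of the image may attend to tokens 0-2
region_mask[:, 8:, 3:6] = 1.0  # the bottom half may attend to tokens 3-5

# Same conversion as the deleted processor: >= 0.5 becomes 0.0 (keep), < 0.5 becomes -10000.0 (suppress).
attn_bias = torch.zeros_like(region_mask)
attn_bias[region_mask < 0.5] = -10000.0

query = torch.randn(1, 1, 16, 8)  # (batch, heads, query_seq_len, head_dim)
key = torch.randn(1, 1, 6, 8)     # (batch, heads, key_seq_len, head_dim)
value = torch.randn(1, 1, 6, 8)
out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_bias.unsqueeze(1))
print(out.shape)  # torch.Size([1, 1, 16, 8])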
@@ -0,0 +1,93 @@
+import torch
+import torch.nn.functional as F
+
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
+    TextConditioningRegions,
+)
+
+
+class RegionalPromptData:
+    def __init__(self, attn_masks_by_seq_len: dict[int, torch.Tensor]):
+        self._attn_masks_by_seq_len = attn_masks_by_seq_len
+
+    @classmethod
+    def from_regions(
+        cls,
+        regions: list[TextConditioningRegions],
+        key_seq_len: int,
+        # TODO(ryand): Pass in a list of downscale factors?
+        max_downscale_factor: int = 8,
+    ):
+        """Construct a `RegionalPromptData` object.
+
+        Args:
+            regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the
+                batch.
+            key_seq_len (int): The sequence length of the expected prompt embeddings (which act as the key in the
+                cross-attention layers). This is most likely equal to the max embedding range end, but we pass it
+                explicitly to be sure.
+        """
+        attn_masks_by_seq_len = {}
+
+        # batch_attn_mask_by_seq_len[b][s] contains the attention mask for the b'th batch sample with a query sequence
+        # length of s.
+        batch_attn_masks_by_seq_len: list[dict[int, torch.Tensor]] = []
+        for batch_sample_regions in regions:
+            batch_attn_masks_by_seq_len.append({})
+
+            # Convert the bool masks to float masks so that max pooling can be applied.
+            batch_masks = batch_sample_regions.masks.to(dtype=torch.float32)
+
+            # Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
+            downscale_factor = 1
+            while downscale_factor <= max_downscale_factor:
+                _, num_prompts, h, w = batch_masks.shape
+                query_seq_len = h * w
+
+                # Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1).
+                batch_query_masks = batch_masks.reshape((1, num_prompts, -1, 1))
+
+                # Create a cross-attention mask for each prompt that selects the corresponding embeddings from
+                # `encoder_hidden_states`.
+                # attn_mask shape: (batch_size, query_seq_len, key_seq_len)
+                # TODO(ryand): What device / dtype should this be?
+                attn_mask = torch.zeros((1, query_seq_len, key_seq_len))
+
+                for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
+                    attn_mask[0, :, embedding_range.start : embedding_range.end] = batch_query_masks[
+                        :, prompt_idx, :, :
+                    ]
+
+                batch_attn_masks_by_seq_len[-1][query_seq_len] = attn_mask
+
+                downscale_factor *= 2
+                if downscale_factor <= max_downscale_factor:
+                    # We use max pooling because we downscale to a pretty low resolution, so we don't want small prompt
+                    # regions to be lost entirely.
+                    # TODO(ryand): In the future, we may want to experiment with other downsampling methods, and could
+                    # potentially use a weighted mask rather than a binary mask.
+                    batch_masks = F.max_pool2d(batch_masks, kernel_size=2, stride=2)
+
+        # Merge the batch_attn_masks_by_seq_len into a single attn_masks_by_seq_len.
+        for query_seq_len in batch_attn_masks_by_seq_len[0].keys():
+            attn_masks_by_seq_len[query_seq_len] = torch.cat(
+                [batch_attn_masks_by_seq_len[i][query_seq_len] for i in range(len(batch_attn_masks_by_seq_len))]
+            )
+
+        return cls(attn_masks_by_seq_len)
+
+    def get_attn_mask(self, query_seq_len: int) -> torch.Tensor:
+        """Get the attention mask for the given query sequence length (i.e. downscaling level).
+
+        This is called during cross-attention, where query_seq_len is the length of the flattened spatial features, so
+        it changes at each downscaling level in the model.
+
+        key_seq_len is the length of the expected prompt embeddings.
+
+        Returns:
+            torch.Tensor: The masks.
+                shape: (batch_size, query_seq_len, key_seq_len).
+                dtype: float
+                The mask is a binary mask with values of 0.0 and 1.0.
+        """
+        return self._attn_masks_by_seq_len[query_seq_len]
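The new `RegionalPromptData` above precomputes one mask per UNet downscaling level, keyed by the flattened query length `h * w`. The following simplified sketch mirrors that loop with a plain tensor standing in for `TextConditioningRegions.masks`; it stops at the flattened per-prompt masks rather than building the full (query, key) attention masks:

import torch
import torch.nn.functional as F

# Toy stand-in for TextConditioningRegions.masks: (batch=1, num_prompts=2, h=64, w=64) bool masks.
masks = torch.zeros(1, 2, 64, 64, dtype=torch.bool)
masks[:, 0, :, :32] = True  # prompt 0 owns the left half
masks[:, 1, :, 32:] = True  # prompt 1 owns the right half

# Mirror the downscaling loop: convert to float, then max-pool by 2 per level, keyed by h * w.
masks_by_query_seq_len = {}
batch_masks = masks.to(dtype=torch.float32)
downscale_factor = 1
max_downscale_factor = 8
while downscale_factor <= max_downscale_factor:
    _, num_prompts, h, w = batch_masks.shape
    masks_by_query_seq_len[h * w] = batch_masks.reshape(1, num_prompts, -1)  # flattened spatial dims
    downscale_factor *= 2
    if downscale_factor <= max_downscale_factor:
        # Max pooling keeps small regions alive at low resolutions (same rationale as the comment above).
        batch_masks = F.max_pool2d(batch_masks, kernel_size=2, stride=2)

print(sorted(masks_by_query_seq_len.keys()))  # [64, 256, 1024, 4096]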
@@ -16,7 +16,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     TextConditioningData,
     TextConditioningRegions,
 )
-from invokeai.backend.stable_diffusion.diffusion.regional_prompt_attention import RegionalPromptData
+from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData

 from .cross_attention_control import (
     CrossAttentionType,

@@ -303,19 +303,13 @@ class InvokeAIDiffuserComponent:
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)

-        cross_attention_kwargs = None
-
-        # TODO(ryand): Figure out interactions between regional prompting and IP-Adapter conditioning.
+        cross_attention_kwargs = {}
         if ip_adapter_conditioning is not None:
             # Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
-            cross_attention_kwargs = {
-                "ip_adapter_image_prompt_embeds": [
-                    torch.stack(
-                        [ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds]
-                    )
-                    for ipa_conditioning in ip_adapter_conditioning
-                ]
-            }
+            cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
+                torch.stack([ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds])
+                for ipa_conditioning in ip_adapter_conditioning
+            ]

         uncond_text = conditioning_data.uncond_text
         cond_text = conditioning_data.cond_text

@@ -352,9 +346,9 @@ class InvokeAIDiffuserComponent:
                 regions.append(r)

            _, key_seq_len, _ = both_conditionings.shape
-            cross_attention_kwargs = {
-                "regional_prompt_data": RegionalPromptData.from_regions(regions=regions, key_seq_len=key_seq_len)
-            }
+            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData.from_regions(
+                regions=regions, key_seq_len=key_seq_len
+            )

         both_results = self.model_forward_callback(
             x_twice,

@@ -424,21 +418,19 @@ class InvokeAIDiffuserComponent:
         # Unconditioned pass
         #####################

-        cross_attention_kwargs = None
+        cross_attention_kwargs = {}

         # Prepare IP-Adapter cross-attention kwargs for the unconditioned pass.
         if ip_adapter_conditioning is not None:
             # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
-            cross_attention_kwargs = {
-                "ip_adapter_image_prompt_embeds": [
-                    torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
-                    for ipa_conditioning in ip_adapter_conditioning
-                ]
-            }
+            cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
+                torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
+                for ipa_conditioning in ip_adapter_conditioning
+            ]

         # Prepare cross-attention control kwargs for the unconditioned pass.
         if cross_attn_processor_context is not None:
-            cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
+            cross_attention_kwargs["swap_cross_attn_context"] = cross_attn_processor_context

         # Prepare SDXL conditioning kwargs for the unconditioned pass.
         added_cond_kwargs = None

@@ -451,11 +443,9 @@ class InvokeAIDiffuserComponent:
         # Prepare prompt regions for the unconditioned pass.
         if conditioning_data.uncond_regions is not None:
             _, key_seq_len, _ = conditioning_data.uncond_text.embeds.shape
-            cross_attention_kwargs = {
-                "regional_prompt_data": RegionalPromptData.from_regions(
-                    regions=[conditioning_data.uncond_regions], key_seq_len=key_seq_len
-                )
-            }
+            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData.from_regions(
+                regions=[conditioning_data.uncond_regions], key_seq_len=key_seq_len
+            )

         # Run unconditioned UNet denoising (i.e. negative prompt).
         unconditioned_next_x = self.model_forward_callback(

@@ -473,22 +463,20 @@ class InvokeAIDiffuserComponent:
         # Conditioned pass
         ###################

-        cross_attention_kwargs = None
+        cross_attention_kwargs = {}

         # Prepare IP-Adapter cross-attention kwargs for the conditioned pass.
         if ip_adapter_conditioning is not None:
             # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
-            cross_attention_kwargs = {
-                "ip_adapter_image_prompt_embeds": [
-                    torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
-                    for ipa_conditioning in ip_adapter_conditioning
-                ]
-            }
+            cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
+                torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
+                for ipa_conditioning in ip_adapter_conditioning
+            ]

         # Prepare cross-attention control kwargs for the conditioned pass.
         if cross_attn_processor_context is not None:
             cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
-            cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
+            cross_attention_kwargs["swap_cross_attn_context"] = cross_attn_processor_context

         # Prepare SDXL conditioning kwargs for the conditioned pass.
         added_cond_kwargs = None

@@ -501,11 +489,9 @@ class InvokeAIDiffuserComponent:
         # Prepare prompt regions for the conditioned pass.
         if conditioning_data.cond_regions is not None:
             _, key_seq_len, _ = conditioning_data.cond_text.embeds.shape
-            cross_attention_kwargs = {
-                "regional_prompt_data": RegionalPromptData.from_regions(
-                    regions=[conditioning_data.cond_regions], key_seq_len=key_seq_len
-                )
-            }
+            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData.from_regions(
+                regions=[conditioning_data.cond_regions], key_seq_len=key_seq_len
+            )

         # Run conditioned UNet denoising (i.e. positive prompt).
         conditioned_next_x = self.model_forward_callback(
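The common thread in the `InvokeAIDiffuserComponent` hunks above is that `cross_attention_kwargs` is now built up incrementally instead of being reassigned wholesale, which is what lets IP-Adapter embeddings, `.swap()` cross-attention control, and regional prompt data coexist in a single UNet call. A small sketch of the pattern (placeholder values, not the component's real locals):

# Build one dict and add keys conditionally; previously each branch did
# `cross_attention_kwargs = {...}`, so the last writer won and earlier conditioning was dropped.
cross_attention_kwargs = {}

ip_adapter_embeds = ["<stacked image-prompt embeds>"]  # placeholder
swap_context = None                                    # placeholder
regional_prompt_data = "<RegionalPromptData>"          # placeholder

if ip_adapter_embeds is not None:
    cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = ip_adapter_embeds
if swap_context is not None:
    cross_attention_kwargs["swap_cross_attn_context"] = swap_context
if regional_prompt_data is not None:
    cross_attention_kwargs["regional_prompt_data"] = regional_prompt_data

print(sorted(cross_attention_kwargs.keys()))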
@@ -1,52 +1,55 @@
 from contextlib import contextmanager
+from typing import Optional

 from diffusers.models import UNet2DConditionModel

-from invokeai.backend.ip_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
+from invokeai.backend.stable_diffusion.diffusion.custom_attention import CustomAttnProcessor2_0


-class UNetPatcher:
-    """A class that contains multiple IP-Adapters and can apply them to a UNet."""
+class UNetAttentionPatcher:
+    """A class for patching a UNet with CustomAttnProcessor2_0 attention layers."""

-    def __init__(self, ip_adapters: list[IPAdapter]):
+    def __init__(self, ip_adapters: Optional[list[IPAdapter]]):
         self._ip_adapters = ip_adapters
-        self._scales = [1.0] * len(self._ip_adapters)
+        self._ip_adapter_scales = None

+        if self._ip_adapters is not None:
+            self._ip_adapter_scales = [1.0] * len(self._ip_adapters)

     def set_scale(self, idx: int, value: float):
-        self._scales[idx] = value
+        self._ip_adapter_scales[idx] = value

     def _prepare_attention_processors(self, unet: UNet2DConditionModel):
         """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
-        weights into them.
+        weights into them (if IP-Adapters are being applied).

         Note that the `unet` param is only used to determine attention block dimensions and naming.
         """
         # Construct a dict of attention processors based on the UNet's architecture.
         attn_procs = {}
         for idx, name in enumerate(unet.attn_processors.keys()):
-            if name.endswith("attn1.processor"):
-                attn_procs[name] = AttnProcessor2_0()
+            if name.endswith("attn1.processor") or self._ip_adapters is None:
+                # "attn1" processors do not use IP-Adapters.
+                attn_procs[name] = CustomAttnProcessor2_0()
             else:
                 # Collect the weights from each IP Adapter for the idx'th attention processor.
-                attn_procs[name] = IPAttnProcessor2_0(
+                attn_procs[name] = CustomAttnProcessor2_0(
                     [ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters],
-                    self._scales,
+                    self._ip_adapter_scales,
                 )
         return attn_procs

     @contextmanager
     def apply_ip_adapter_attention(self, unet: UNet2DConditionModel):
-        """A context manager that patches `unet` with IP-Adapter attention processors."""
+        """A context manager that patches `unet` with CustomAttnProcessor2_0 attention layers."""

         attn_procs = self._prepare_attention_processors(unet)

         orig_attn_processors = unet.attn_processors

         try:
-            # Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from the
-            # passed dict. So, if you wanted to keep the dict for future use, you'd have to make a moderately-shallow copy
-            # of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
+            # Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from
+            # the passed dict. So, if you wanted to keep the dict for future use, you'd have to make a
+            # moderately-shallow copy of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
             unet.set_attn_processor(attn_procs)
             yield None
         finally:
@@ -1,8 +1,8 @@
 import pytest
 import torch

-from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
 from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
+from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
 from invokeai.backend.util.test_utils import install_and_load_model


@@ -77,7 +77,7 @@ def test_ip_adapter_unet_patch(model_params, model_installer, torch_device):
     ip_embeds = torch.randn((1, 3, 4, 768)).to(torch_device)

     cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [ip_embeds]}
-    ip_adapter_unet_patcher = UNetPatcher([ip_adapter])
+    ip_adapter_unet_patcher = UNetAttentionPatcher([ip_adapter])
     with ip_adapter_unet_patcher.apply_ip_adapter_attention(unet):
         output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample