From 8989a6cdc61f90eb6a8462aaada0b9f767fe07c6 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 29 Feb 2024 14:54:13 -0500 Subject: [PATCH] Get multi-prompt attention working simultaneously with IP-adapter. --- .../backend/ip_adapter/attention_processor.py | 185 ---------------- .../stable_diffusion/diffusers_pipeline.py | 25 +-- .../diffusion}/custom_attention.py | 5 +- .../diffusion/regional_prompt_attention.py | 208 ------------------ .../diffusion/regional_prompt_data.py | 93 ++++++++ .../diffusion/shared_invokeai_diffusion.py | 68 +++--- .../diffusion/unet_attention_patcher.py} | 37 ++-- tests/backend/ip_adapter/test_ip_adapter.py | 4 +- 8 files changed, 154 insertions(+), 471 deletions(-) delete mode 100644 invokeai/backend/ip_adapter/attention_processor.py rename invokeai/backend/{attention => stable_diffusion/diffusion}/custom_attention.py (99%) delete mode 100644 invokeai/backend/stable_diffusion/diffusion/regional_prompt_attention.py create mode 100644 invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py rename invokeai/backend/{ip_adapter/unet_patcher.py => stable_diffusion/diffusion/unet_attention_patcher.py} (53%) diff --git a/invokeai/backend/ip_adapter/attention_processor.py b/invokeai/backend/ip_adapter/attention_processor.py deleted file mode 100644 index 767a366d0c..0000000000 --- a/invokeai/backend/ip_adapter/attention_processor.py +++ /dev/null @@ -1,185 +0,0 @@ -# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0) -# and modified as needed - -# tencent-ailab comment: -# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py -import torch -import torch.nn as nn -import torch.nn.functional as F -from diffusers.models.attention_processor import AttnProcessor2_0 as DiffusersAttnProcessor2_0 - -from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights - - -# Create a version of AttnProcessor2_0 that is a sub-class of nn.Module. This is required for IP-Adapter state_dict -# loading. -class AttnProcessor2_0(DiffusersAttnProcessor2_0, nn.Module): - def __init__(self): - DiffusersAttnProcessor2_0.__init__(self) - nn.Module.__init__(self) - - def __call__( - self, - attn, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - ip_adapter_image_prompt_embeds=None, - ): - """Re-definition of DiffusersAttnProcessor2_0.__call__(...) that accepts and ignores the - ip_adapter_image_prompt_embeds parameter. - """ - return DiffusersAttnProcessor2_0.__call__( - self, attn, hidden_states, encoder_hidden_states, attention_mask, temb - ) - - -class IPAttnProcessor2_0(torch.nn.Module): - r""" - Attention processor for IP-Adapater for PyTorch 2.0. - Args: - hidden_size (`int`): - The hidden size of the attention layer. - cross_attention_dim (`int`): - The number of channels in the `encoder_hidden_states`. - scale (`float`, defaults to 1.0): - the weight scale of image prompt. - """ - - def __init__(self, weights: list[IPAttentionProcessorWeights], scales: list[float]): - super().__init__() - - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - assert len(weights) == len(scales) - - self._weights = weights - self._scales = scales - - def __call__( - self, - attn, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - ip_adapter_image_prompt_embeds=None, - ): - """Apply IP-Adapter attention. 
- - Args: - ip_adapter_image_prompt_embeds (torch.Tensor): The image prompt embeddings. - Shape: (batch_size, num_ip_images, seq_len, ip_embedding_len). - """ - # If true, we are doing cross-attention, if false we are doing self-attention. - is_cross_attention = encoder_hidden_states is not None - - residual = hidden_states - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - if is_cross_attention: - # If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case, - # we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here. - assert ip_adapter_image_prompt_embeds is not None - assert len(ip_adapter_image_prompt_embeds) == len(self._weights) - - for ipa_embed, ipa_weights, scale in zip( - ip_adapter_image_prompt_embeds, self._weights, self._scales, strict=True - ): - # The batch dimensions should match. - assert ipa_embed.shape[0] == encoder_hidden_states.shape[0] - # The token_len dimensions should match. 
- assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1] - - ip_hidden_states = ipa_embed - - # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding) - - ip_key = ipa_weights.to_k_ip(ip_hidden_states) - ip_value = ipa_weights.to_v_ip(ip_hidden_states) - - # Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads) - - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim) - - # TODO: add support for attn.scale when we move to Torch 2.1 - ip_hidden_states = F.scaled_dot_product_attention( - query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) - - # Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim) - - ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - ip_hidden_states = ip_hidden_states.to(query.dtype) - - # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim) - - hidden_states = hidden_states + scale * ip_hidden_states - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index bbdea3466f..b13a1271eb 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -23,13 +23,12 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.ip_adapter.ip_adapter import IPAdapter -from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( IPAdapterConditioningInfo, TextConditioningData, ) -from invokeai.backend.stable_diffusion.diffusion.regional_prompt_attention import apply_regional_prompt_attn from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent +from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher from ..util import auto_detect_slice_size, normalize_device @@ -427,11 +426,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): raise ValueError( "Prompt-to-prompt cross-attention control (`.swap()`) and regional prompting cannot be used simultaneously." ) - if use_ip_adapter and use_regional_prompting: - # TODO(ryand): Implement this. 
- raise NotImplementedError("Coming soon.") - ip_adapter_unet_patcher = None + unet_attention_patcher = None self.use_ip_adapter = use_ip_adapter attn_ctx = nullcontext() if use_cross_attention_control: @@ -439,11 +435,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): self.invokeai_diffuser.model, extra_conditioning_info=extra_conditioning_info, ) - if use_ip_adapter: - ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data]) - attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model) - if use_regional_prompting: - attn_ctx = apply_regional_prompt_attn(self.invokeai_diffuser.model) + if use_ip_adapter or use_regional_prompting: + ip_adapters = [ipa.ip_adapter_model for ipa in ip_adapter_data] if use_ip_adapter else None + unet_attention_patcher = UNetAttentionPatcher(ip_adapters) + attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model) with attn_ctx: if callback is not None: @@ -471,7 +466,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): control_data=control_data, ip_adapter_data=ip_adapter_data, t2i_adapter_data=t2i_adapter_data, - ip_adapter_unet_patcher=ip_adapter_unet_patcher, + unet_attention_patcher=unet_attention_patcher, ) latents = step_output.prev_sample predicted_original = getattr(step_output, "pred_original_sample", None) @@ -503,7 +498,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): control_data: List[ControlNetData] = None, ip_adapter_data: Optional[list[IPAdapterData]] = None, t2i_adapter_data: Optional[list[T2IAdapterData]] = None, - ip_adapter_unet_patcher: Optional[UNetPatcher] = None, + unet_attention_patcher: Optional[UNetAttentionPatcher] = None, ): # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value timestep = t[0] @@ -526,10 +521,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): ) if step_index >= first_adapter_step and step_index <= last_adapter_step: # Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range. - ip_adapter_unet_patcher.set_scale(i, weight) + unet_attention_patcher.set_scale(i, weight) else: # Otherwise, set the IP-Adapter's scale to 0, so it has no effect. - ip_adapter_unet_patcher.set_scale(i, 0.0) + unet_attention_patcher.set_scale(i, 0.0) # Handle ControlNet(s) down_block_additional_residuals = None diff --git a/invokeai/backend/attention/custom_attention.py b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py similarity index 99% rename from invokeai/backend/attention/custom_attention.py rename to invokeai/backend/stable_diffusion/diffusion/custom_attention.py index 69c2eae5b6..58aba2f709 100644 --- a/invokeai/backend/attention/custom_attention.py +++ b/invokeai/backend/stable_diffusion/diffusion/custom_attention.py @@ -6,7 +6,7 @@ from diffusers.models.attention_processor import Attention, AttnProcessor2_0 from diffusers.utils import USE_PEFT_BACKEND from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights -from invokeai.backend.stable_diffusion.diffusion.regional_prompt_attention import RegionalPromptData +from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData class CustomAttnProcessor2_0(AttnProcessor2_0): @@ -149,10 +149,9 @@ class CustomAttnProcessor2_0(AttnProcessor2_0): # End unmodified block from AttnProcessor2_0. # Apply IP-Adapter conditioning. 
- if is_cross_attention: + if is_cross_attention and self._is_ip_adapter_enabled(): if self._is_ip_adapter_enabled(): assert ip_adapter_image_prompt_embeds is not None - for ipa_embed, ipa_weights, scale in zip( ip_adapter_image_prompt_embeds, self._ip_adapter_weights, self._ip_adapter_scales, strict=True ): diff --git a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_attention.py b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_attention.py deleted file mode 100644 index 74d7c2755d..0000000000 --- a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_attention.py +++ /dev/null @@ -1,208 +0,0 @@ -from contextlib import contextmanager -from typing import Optional - -import torch -import torch.nn.functional as F -from diffusers import UNet2DConditionModel -from diffusers.models.attention_processor import Attention, AttnProcessor2_0 -from diffusers.utils import USE_PEFT_BACKEND - -from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( - TextConditioningRegions, -) - - -class RegionalPromptData: - def __init__(self, attn_masks_by_seq_len: dict[int, torch.Tensor]): - self._attn_masks_by_seq_len = attn_masks_by_seq_len - - @classmethod - def from_regions( - cls, - regions: list[TextConditioningRegions], - key_seq_len: int, - # TODO(ryand): Pass in a list of downscale factors? - max_downscale_factor: int = 8, - ): - """Construct a `RegionalPromptData` object. - - Args: - regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the - batch. - key_seq_len (int): The sequence length of the expected prompt embeddings (which act as the key in the - cross-attention layers). This is most likely equal to the max embedding range end, but we pass it - explicitly to be sure. - """ - attn_masks_by_seq_len = {} - - # batch_attn_mask_by_seq_len[b][s] contains the attention mask for the b'th batch sample with a query sequence - # length of s. - batch_attn_masks_by_seq_len: list[dict[int, torch.Tensor]] = [] - for batch_sample_regions in regions: - batch_attn_masks_by_seq_len.append({}) - - # Convert the bool masks to float masks so that max pooling can be applied. - batch_masks = batch_sample_regions.masks.to(dtype=torch.float32) - - # Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached. - downscale_factor = 1 - while downscale_factor <= max_downscale_factor: - _, num_prompts, h, w = batch_masks.shape - query_seq_len = h * w - - # Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1). - batch_query_masks = batch_masks.reshape((1, num_prompts, -1, 1)) - - # Create a cross-attention mask for each prompt that selects the corresponding embeddings from - # `encoder_hidden_states`. - # attn_mask shape: (batch_size, query_seq_len, key_seq_len) - # TODO(ryand): What device / dtype should this be? - attn_mask = torch.zeros((1, query_seq_len, key_seq_len)) - - for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges): - attn_mask[0, :, embedding_range.start : embedding_range.end] = batch_query_masks[ - :, prompt_idx, :, : - ] - - batch_attn_masks_by_seq_len[-1][query_seq_len] = attn_mask - - downscale_factor *= 2 - if downscale_factor <= max_downscale_factor: - # We use max pooling because we downscale to a pretty low resolution, so we don't want small prompt - # regions to be lost entirely. 
- # TODO(ryand): In the future, we may want to experiment with other downsampling methods, and could - # potentially use a weighted mask rather than a binary mask. - batch_masks = F.max_pool2d(batch_masks, kernel_size=2, stride=2) - - # Merge the batch_attn_masks_by_seq_len into a single attn_masks_by_seq_len. - for query_seq_len in batch_attn_masks_by_seq_len[0].keys(): - attn_masks_by_seq_len[query_seq_len] = torch.cat( - [batch_attn_masks_by_seq_len[i][query_seq_len] for i in range(len(batch_attn_masks_by_seq_len))] - ) - - return cls(attn_masks_by_seq_len) - - def get_attn_mask(self, query_seq_len: int) -> torch.Tensor: - """Get the attention mask for the given query sequence length (i.e. downscaling level). - - This is called during cross-attention, where query_seq_len is the length of the flattened spatial features, so - it changes at each downscaling level in the model. - - key_seq_len is the length of the expected prompt embeddings. - - Returns: - torch.Tensor: The masks. - shape: (batch_size, query_seq_len, key_seq_len). - dtype: float - The mask is a binary mask with values of 0.0 and 1.0. - """ - return self._attn_masks_by_seq_len[query_seq_len] - - -class RegionalPromptAttnProcessor2_0(AttnProcessor2_0): - """An attention processor that supports regional prompt attention for PyTorch 2.0.""" - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - temb: Optional[torch.FloatTensor] = None, - scale: float = 1.0, - regional_prompt_data: Optional[RegionalPromptData] = None, - ) -> torch.FloatTensor: - residual = hidden_states - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if encoder_hidden_states is not None and regional_prompt_data is not None: - # If encoder_hidden_states is not None, that means we are doing cross-attention case. - _, query_seq_len, _ = hidden_states.shape - prompt_region_attention_mask = regional_prompt_data.get_attn_mask(query_seq_len) - # TODO(ryand): Avoid redundant type/device conversion here. 
- prompt_region_attention_mask = prompt_region_attention_mask.to( - dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device - ) - prompt_region_attention_mask[prompt_region_attention_mask < 0.5] = -10000.0 - prompt_region_attention_mask[prompt_region_attention_mask >= 0.5] = 0.0 - - if attention_mask is None: - attention_mask = prompt_region_attention_mask - else: - attention_mask = prompt_region_attention_mask + attention_mask - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - args = () if USE_PEFT_BACKEND else (scale,) - query = attn.to_q(hidden_states, *args) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states, *args) - value = attn.to_v(encoder_hidden_states, *args) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # linear proj - hidden_states = attn.to_out[0](hidden_states, *args) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - -@contextmanager -def apply_regional_prompt_attn(unet: UNet2DConditionModel): - """A context manager that patches `unet` with RegionalPromptAttnProcessor2_0 attention processors.""" - - orig_attn_processors = unet.attn_processors - - try: - unet.set_attn_processor(RegionalPromptAttnProcessor2_0()) - yield None - finally: - unet.set_attn_processor(orig_attn_processors) diff --git a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py new file mode 100644 index 0000000000..48b558d054 --- /dev/null +++ b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py @@ -0,0 +1,93 @@ +import torch +import torch.nn.functional as F + +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( + TextConditioningRegions, +) + + +class RegionalPromptData: + def __init__(self, attn_masks_by_seq_len: dict[int, torch.Tensor]): + self._attn_masks_by_seq_len = attn_masks_by_seq_len + + @classmethod + def from_regions( + cls, + regions: list[TextConditioningRegions], + key_seq_len: int, + # TODO(ryand): Pass in a list of downscale factors? 
+ max_downscale_factor: int = 8, + ): + """Construct a `RegionalPromptData` object. + + Args: + regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the + batch. + key_seq_len (int): The sequence length of the expected prompt embeddings (which act as the key in the + cross-attention layers). This is most likely equal to the max embedding range end, but we pass it + explicitly to be sure. + """ + attn_masks_by_seq_len = {} + + # batch_attn_mask_by_seq_len[b][s] contains the attention mask for the b'th batch sample with a query sequence + # length of s. + batch_attn_masks_by_seq_len: list[dict[int, torch.Tensor]] = [] + for batch_sample_regions in regions: + batch_attn_masks_by_seq_len.append({}) + + # Convert the bool masks to float masks so that max pooling can be applied. + batch_masks = batch_sample_regions.masks.to(dtype=torch.float32) + + # Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached. + downscale_factor = 1 + while downscale_factor <= max_downscale_factor: + _, num_prompts, h, w = batch_masks.shape + query_seq_len = h * w + + # Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1). + batch_query_masks = batch_masks.reshape((1, num_prompts, -1, 1)) + + # Create a cross-attention mask for each prompt that selects the corresponding embeddings from + # `encoder_hidden_states`. + # attn_mask shape: (batch_size, query_seq_len, key_seq_len) + # TODO(ryand): What device / dtype should this be? + attn_mask = torch.zeros((1, query_seq_len, key_seq_len)) + + for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges): + attn_mask[0, :, embedding_range.start : embedding_range.end] = batch_query_masks[ + :, prompt_idx, :, : + ] + + batch_attn_masks_by_seq_len[-1][query_seq_len] = attn_mask + + downscale_factor *= 2 + if downscale_factor <= max_downscale_factor: + # We use max pooling because we downscale to a pretty low resolution, so we don't want small prompt + # regions to be lost entirely. + # TODO(ryand): In the future, we may want to experiment with other downsampling methods, and could + # potentially use a weighted mask rather than a binary mask. + batch_masks = F.max_pool2d(batch_masks, kernel_size=2, stride=2) + + # Merge the batch_attn_masks_by_seq_len into a single attn_masks_by_seq_len. + for query_seq_len in batch_attn_masks_by_seq_len[0].keys(): + attn_masks_by_seq_len[query_seq_len] = torch.cat( + [batch_attn_masks_by_seq_len[i][query_seq_len] for i in range(len(batch_attn_masks_by_seq_len))] + ) + + return cls(attn_masks_by_seq_len) + + def get_attn_mask(self, query_seq_len: int) -> torch.Tensor: + """Get the attention mask for the given query sequence length (i.e. downscaling level). + + This is called during cross-attention, where query_seq_len is the length of the flattened spatial features, so + it changes at each downscaling level in the model. + + key_seq_len is the length of the expected prompt embeddings. + + Returns: + torch.Tensor: The masks. + shape: (batch_size, query_seq_len, key_seq_len). + dtype: float + The mask is a binary mask with values of 0.0 and 1.0. 
+ """ + return self._attn_masks_by_seq_len[query_seq_len] diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index d66f7d83a1..00ec43dd6b 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -16,7 +16,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( TextConditioningData, TextConditioningRegions, ) -from invokeai.backend.stable_diffusion.diffusion.regional_prompt_attention import RegionalPromptData +from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData from .cross_attention_control import ( CrossAttentionType, @@ -303,19 +303,13 @@ class InvokeAIDiffuserComponent: x_twice = torch.cat([x] * 2) sigma_twice = torch.cat([sigma] * 2) - cross_attention_kwargs = None - - # TODO(ryand): Figure out interactions between regional prompting and IP-Adapter conditioning. + cross_attention_kwargs = {} if ip_adapter_conditioning is not None: # Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len). - cross_attention_kwargs = { - "ip_adapter_image_prompt_embeds": [ - torch.stack( - [ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds] - ) - for ipa_conditioning in ip_adapter_conditioning - ] - } + cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [ + torch.stack([ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds]) + for ipa_conditioning in ip_adapter_conditioning + ] uncond_text = conditioning_data.uncond_text cond_text = conditioning_data.cond_text @@ -352,9 +346,9 @@ class InvokeAIDiffuserComponent: regions.append(r) _, key_seq_len, _ = both_conditionings.shape - cross_attention_kwargs = { - "regional_prompt_data": RegionalPromptData.from_regions(regions=regions, key_seq_len=key_seq_len) - } + cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData.from_regions( + regions=regions, key_seq_len=key_seq_len + ) both_results = self.model_forward_callback( x_twice, @@ -424,21 +418,19 @@ class InvokeAIDiffuserComponent: # Unconditioned pass ##################### - cross_attention_kwargs = None + cross_attention_kwargs = {} # Prepare IP-Adapter cross-attention kwargs for the unconditioned pass. if ip_adapter_conditioning is not None: # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len). - cross_attention_kwargs = { - "ip_adapter_image_prompt_embeds": [ - torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0) - for ipa_conditioning in ip_adapter_conditioning - ] - } + cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [ + torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0) + for ipa_conditioning in ip_adapter_conditioning + ] # Prepare cross-attention control kwargs for the unconditioned pass. if cross_attn_processor_context is not None: - cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context} + cross_attention_kwargs["swap_cross_attn_context"] = cross_attn_processor_context # Prepare SDXL conditioning kwargs for the unconditioned pass. added_cond_kwargs = None @@ -451,11 +443,9 @@ class InvokeAIDiffuserComponent: # Prepare prompt regions for the unconditioned pass. 
if conditioning_data.uncond_regions is not None: _, key_seq_len, _ = conditioning_data.uncond_text.embeds.shape - cross_attention_kwargs = { - "regional_prompt_data": RegionalPromptData.from_regions( - regions=[conditioning_data.uncond_regions], key_seq_len=key_seq_len - ) - } + cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData.from_regions( + regions=[conditioning_data.uncond_regions], key_seq_len=key_seq_len + ) # Run unconditioned UNet denoising (i.e. negative prompt). unconditioned_next_x = self.model_forward_callback( @@ -473,22 +463,20 @@ class InvokeAIDiffuserComponent: # Conditioned pass ################### - cross_attention_kwargs = None + cross_attention_kwargs = {} # Prepare IP-Adapter cross-attention kwargs for the conditioned pass. if ip_adapter_conditioning is not None: # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len). - cross_attention_kwargs = { - "ip_adapter_image_prompt_embeds": [ - torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0) - for ipa_conditioning in ip_adapter_conditioning - ] - } + cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [ + torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0) + for ipa_conditioning in ip_adapter_conditioning + ] # Prepare cross-attention control kwargs for the conditioned pass. if cross_attn_processor_context is not None: cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do - cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context} + cross_attention_kwargs["swap_cross_attn_context"] = cross_attn_processor_context # Prepare SDXL conditioning kwargs for the conditioned pass. added_cond_kwargs = None @@ -501,11 +489,9 @@ class InvokeAIDiffuserComponent: # Prepare prompt regions for the conditioned pass. if conditioning_data.cond_regions is not None: _, key_seq_len, _ = conditioning_data.cond_text.embeds.shape - cross_attention_kwargs = { - "regional_prompt_data": RegionalPromptData.from_regions( - regions=[conditioning_data.cond_regions], key_seq_len=key_seq_len - ) - } + cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData.from_regions( + regions=[conditioning_data.cond_regions], key_seq_len=key_seq_len + ) # Run conditioned UNet denoising (i.e. positive prompt). 
conditioned_next_x = self.model_forward_callback( diff --git a/invokeai/backend/ip_adapter/unet_patcher.py b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py similarity index 53% rename from invokeai/backend/ip_adapter/unet_patcher.py rename to invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py index f8c1870f6e..344768a9bf 100644 --- a/invokeai/backend/ip_adapter/unet_patcher.py +++ b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py @@ -1,52 +1,55 @@ from contextlib import contextmanager +from typing import Optional from diffusers.models import UNet2DConditionModel -from invokeai.backend.ip_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter +from invokeai.backend.stable_diffusion.diffusion.custom_attention import CustomAttnProcessor2_0 -class UNetPatcher: - """A class that contains multiple IP-Adapters and can apply them to a UNet.""" +class UNetAttentionPatcher: + """A class for patching a UNet with CustomAttnProcessor2_0 attention layers.""" - def __init__(self, ip_adapters: list[IPAdapter]): + def __init__(self, ip_adapters: Optional[list[IPAdapter]]): self._ip_adapters = ip_adapters - self._scales = [1.0] * len(self._ip_adapters) + self._ip_adapter_scales = None + + if self._ip_adapters is not None: + self._ip_adapter_scales = [1.0] * len(self._ip_adapters) def set_scale(self, idx: int, value: float): - self._scales[idx] = value + self._ip_adapter_scales[idx] = value def _prepare_attention_processors(self, unet: UNet2DConditionModel): """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention - weights into them. + weights into them (if IP-Adapters are being applied). Note that the `unet` param is only used to determine attention block dimensions and naming. """ # Construct a dict of attention processors based on the UNet's architecture. attn_procs = {} for idx, name in enumerate(unet.attn_processors.keys()): - if name.endswith("attn1.processor"): - attn_procs[name] = AttnProcessor2_0() + if name.endswith("attn1.processor") or self._ip_adapters is None: + # "attn1" processors do not use IP-Adapters. + attn_procs[name] = CustomAttnProcessor2_0() else: # Collect the weights from each IP Adapter for the idx'th attention processor. - attn_procs[name] = IPAttnProcessor2_0( + attn_procs[name] = CustomAttnProcessor2_0( [ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters], - self._scales, + self._ip_adapter_scales, ) return attn_procs @contextmanager def apply_ip_adapter_attention(self, unet: UNet2DConditionModel): - """A context manager that patches `unet` with IP-Adapter attention processors.""" - + """A context manager that patches `unet` with CustomAttnProcessor2_0 attention layers.""" attn_procs = self._prepare_attention_processors(unet) - orig_attn_processors = unet.attn_processors try: - # Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from the - # passed dict. So, if you wanted to keep the dict for future use, you'd have to make a moderately-shallow copy - # of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`. + # Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from + # the passed dict. So, if you wanted to keep the dict for future use, you'd have to make a + # moderately-shallow copy of it. E.g. 
`attn_procs_copy = {k: v for k, v in attn_procs.items()}`. unet.set_attn_processor(attn_procs) yield None finally: diff --git a/tests/backend/ip_adapter/test_ip_adapter.py b/tests/backend/ip_adapter/test_ip_adapter.py index 6a3ec510a2..49bfb8d296 100644 --- a/tests/backend/ip_adapter/test_ip_adapter.py +++ b/tests/backend/ip_adapter/test_ip_adapter.py @@ -1,8 +1,8 @@ import pytest import torch -from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType +from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher from invokeai.backend.util.test_utils import install_and_load_model @@ -77,7 +77,7 @@ def test_ip_adapter_unet_patch(model_params, model_installer, torch_device): ip_embeds = torch.randn((1, 3, 4, 768)).to(torch_device) cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [ip_embeds]} - ip_adapter_unet_patcher = UNetPatcher([ip_adapter]) + ip_adapter_unet_patcher = UNetAttentionPatcher([ip_adapter]) with ip_adapter_unet_patcher.apply_ip_adapter_attention(unet): output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample
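

The core of the regional-prompting path is the per-resolution attention mask built by `RegionalPromptData.from_regions`. A minimal standalone sketch of that construction follows; `region_masks` and `embedding_ranges` are illustrative stand-ins for the `masks` and `ranges` fields of `TextConditioningRegions`, not the exact API:

import torch
import torch.nn.functional as F


def build_regional_attn_masks(
    region_masks: torch.Tensor,  # (1, num_prompts, h, w), bool masks at latent resolution
    embedding_ranges: list[tuple[int, int]],  # (start, end) token range per prompt
    key_seq_len: int,
    max_downscale_factor: int = 8,
) -> dict[int, torch.Tensor]:
    """Build one cross-attention mask per UNet downscaling level, keyed by query_seq_len."""
    masks = region_masks.to(dtype=torch.float32)  # float so that max pooling can be applied
    attn_masks_by_seq_len: dict[int, torch.Tensor] = {}

    downscale_factor = 1
    while downscale_factor <= max_downscale_factor:
        _, num_prompts, h, w = masks.shape
        query_seq_len = h * w

        # Flatten the spatial dims: query position q may attend to key position k only if
        # k falls in the embedding range of a prompt whose region covers q.
        query_masks = masks.reshape((1, num_prompts, -1, 1))
        attn_mask = torch.zeros((1, query_seq_len, key_seq_len))
        for prompt_idx, (start, end) in enumerate(embedding_ranges):
            attn_mask[0, :, start:end] = query_masks[:, prompt_idx, :, :]
        attn_masks_by_seq_len[query_seq_len] = attn_mask

        downscale_factor *= 2
        if downscale_factor <= max_downscale_factor:
            # Max pooling (rather than averaging) keeps small regions alive at low resolutions.
            masks = F.max_pool2d(masks, kernel_size=2, stride=2)

    return attn_masks_by_seq_len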
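

At attention time, the binary mask has to become an additive bias before it reaches `scaled_dot_product_attention`. The deleted `RegionalPromptAttnProcessor2_0` shows the intended conversion (values below 0.5 become -10000.0, values at or above 0.5 become 0.0, then any pre-existing attention mask is added); a sketch with illustrative names:

from typing import Optional

import torch


def regional_mask_to_bias(mask: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
    # Allowed (query, key) pairs contribute a bias of 0.0; disallowed pairs get a large
    # negative bias so that softmax drives their attention weights to ~0.
    bias = torch.zeros_like(mask)
    bias[mask < 0.5] = -10000.0
    return bias if attention_mask is None else bias + attention_mask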
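

The IP-Adapter side of `CustomAttnProcessor2_0` runs a second scaled-dot-product attention from the same spatial queries against keys/values projected from the image-prompt embeddings, then blends the result into the text cross-attention output with a per-adapter scale. A rough sketch of that decoupled attention; the tensor names are illustrative, and `to_k_ip`/`to_v_ip` correspond to the per-layer projections in `IPAttentionProcessorWeights`:

import torch
import torch.nn.functional as F


def add_ip_adapter_attention(
    query: torch.Tensor,          # (batch, heads, query_seq_len, head_dim)
    text_attn_out: torch.Tensor,  # (batch, query_seq_len, heads * head_dim)
    ip_embeds: torch.Tensor,      # (batch, num_ip_images, ip_seq_len, ip_embedding_len)
    to_k_ip: torch.nn.Linear,
    to_v_ip: torch.nn.Linear,
    scale: float,
) -> torch.Tensor:
    batch, heads, _, head_dim = query.shape
    # Project the image-prompt embeddings to per-head keys/values, flattening the
    # (num_ip_images, ip_seq_len) dims into a single key sequence.
    ip_key = to_k_ip(ip_embeds).view(batch, -1, heads, head_dim).transpose(1, 2)
    ip_value = to_v_ip(ip_embeds).view(batch, -1, heads, head_dim).transpose(1, 2)
    # Attend from the same spatial queries to the image-prompt tokens.
    ip_out = F.scaled_dot_product_attention(
        query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
    )
    ip_out = ip_out.transpose(1, 2).reshape(batch, -1, heads * head_dim).to(query.dtype)
    # Blend into the text-conditioned result with a per-adapter scale.
    return text_attn_out + scale * ip_out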
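

What lets the two features compose in this commit is that `shared_invokeai_diffusion.py` now accumulates entries into a single `cross_attention_kwargs` dict instead of overwriting it. A condensed sketch of the pattern (the conditioning variables are illustrative):

from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData

cross_attention_kwargs = {}
if ip_adapter_conditioning is not None:
    cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = image_prompt_embeds
if regions is not None:
    cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData.from_regions(
        regions=regions, key_seq_len=key_seq_len
    )
# Both entries now reach CustomAttnProcessor2_0 in the same UNet call.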
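

Finally, a hypothetical usage sketch of the unified `UNetAttentionPatcher`. Passing `None` for `ip_adapters` covers the regional-prompting-only case; with adapters, `set_scale` mirrors the per-step begin/end weighting in `diffusers_pipeline.py`. Here `unet`, `latents`, `t`, `embeds`, `regional_prompt_data`, and `ip_adapter` are assumed to be in scope:

from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher

# Regional prompting only: no IP-Adapters, so every attention layer gets a plain
# CustomAttnProcessor2_0.
patcher = UNetAttentionPatcher(ip_adapters=None)
with patcher.apply_ip_adapter_attention(unet):  # despite the name, this also enables regional masks
    noise_pred = unet(
        latents,
        t,
        encoder_hidden_states=embeds,
        cross_attention_kwargs={"regional_prompt_data": regional_prompt_data},
    ).sample

# With IP-Adapters, the denoising loop can modulate each adapter per step:
patcher = UNetAttentionPatcher([ip_adapter])
patcher.set_scale(0, 0.0)  # e.g. zero out the first adapter outside its begin/end step range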