From caa690e24d54b8984eb404f240c5b531ac58d1e0 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 16 Feb 2024 17:09:06 -0500 Subject: [PATCH] Add concatenation of multiple text conditioning tensors, and patching of RegionalPromptAttnProcessor2_0 into the UNet. --- .../stable_diffusion/diffusers_pipeline.py | 6 +- .../diffusion/regional_prompt_attention.py | 18 ++++++ .../diffusion/shared_invokeai_diffusion.py | 55 ++++++++++++++++++- 3 files changed, 74 insertions(+), 5 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index d34016d128..febd9ad792 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -25,6 +25,7 @@ from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.ip_adapter.ip_adapter import IPAdapter from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData +from invokeai.backend.stable_diffusion.diffusion.regional_prompt_attention import apply_regional_prompt_attn from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent from ..util import auto_detect_slice_size, normalize_device @@ -415,20 +416,19 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): ) ip_adapter_unet_patcher = None + self.use_ip_adapter = use_ip_adapter if use_cross_attention_control: attn_ctx = self.invokeai_diffuser.custom_attention_context( self.invokeai_diffuser.model, extra_conditioning_info=extra_conditioning_info, ) - self.use_ip_adapter = False elif use_ip_adapter: # TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active? # As it is now, the IP-Adapter will silently be skipped. 
ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data]) attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model) - self.use_ip_adapter = True elif use_regional_prompting: - raise NotImplementedError("Regional prompting is not yet supported.") + attn_ctx = apply_regional_prompt_attn(self.invokeai_diffuser.model) else: attn_ctx = nullcontext() diff --git a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_attention.py b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_attention.py index ef0790b3cb..6d4ef77745 100644 --- a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_attention.py +++ b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_attention.py @@ -1,7 +1,9 @@ +from contextlib import contextmanager from typing import Optional import torch import torch.nn.functional as F +from diffusers import UNet2DConditionModel from diffusers.models.attention_processor import Attention, AttnProcessor2_0 from diffusers.utils import USE_PEFT_BACKEND @@ -17,7 +19,10 @@ class RegionalPromptAttnProcessor2_0(AttnProcessor2_0): attention_mask: Optional[torch.FloatTensor] = None, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0, + regional_prompt_data=None, ) -> torch.FloatTensor: + assert regional_prompt_data is None + residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) @@ -83,3 +88,16 @@ class RegionalPromptAttnProcessor2_0(AttnProcessor2_0): hidden_states = hidden_states / attn.rescale_output_factor return hidden_states + + +@contextmanager +def apply_regional_prompt_attn(unet: UNet2DConditionModel): + """A context manager that patches `unet` with RegionalPromptAttnProcessor2_0 attention processors.""" + + orig_attn_processors = unet.attn_processors + + try: + unet.set_attn_processor(RegionalPromptAttnProcessor2_0()) + yield None + finally: + unet.set_attn_processor(orig_attn_processors) diff --git 
a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index a2ec7fc891..f202c121ff 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -10,6 +10,7 @@ from typing_extensions import TypeAlias from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( + BasicConditioningInfo, ConditioningData, ExtraConditioningInfo, PostprocessingSettings, @@ -309,6 +310,55 @@ class InvokeAIDiffuserComponent: # methods below are called from do_diffusion_step and should be considered private to this class. + def _prepare_text_embeddings( + self, text_embeddings: list[Union[BasicConditioningInfo, SDXLConditioningInfo]] + ) -> Union[BasicConditioningInfo, SDXLConditioningInfo]: + if len(text_embeddings) == 1: + # If there is only one text embedding, we can just return it. + # We short-circuit here, because there are some features that are only supported when there is a single + # text_embedding provided. + return text_embeddings[0] + + is_sdxl = type(text_embeddings[0]) is SDXLConditioningInfo + + text_embedding = [] + pooled_embedding = None + add_time_ids = None + + for text_embedding_info in text_embeddings: + # TODO(ryand): Having to check this feels super hacky. + # Extra conditioning is not supported when there are multiple text embeddings. + assert ( + text_embedding_info.extra_conditioning is None + or not text_embedding_info.extra_conditioning.wants_cross_attention_control + ) + + if is_sdxl: + # We just use the first SDXLConditioningInfo's pooled_embeds and add_time_ids. + # TODO(ryand): Think about this some more. If we can't use the pooled_embeds and add_time_ids from all + # the conditioning info, then we shouldn't allow it to be passed in.
+ if pooled_embedding is None: + pooled_embedding = text_embedding_info.pooled_embeds + if add_time_ids is None: + add_time_ids = text_embedding_info.add_time_ids + + text_embedding.append(text_embedding_info.embeds) + + text_embedding = torch.cat(text_embedding, dim=1) + assert len(text_embedding.shape) == 3 # batch_size, seq_len, token_len + + if is_sdxl: + return SDXLConditioningInfo( + embeds=text_embedding, + extra_conditioning=None, + pooled_embeds=pooled_embedding, + add_time_ids=add_time_ids, + ) + return BasicConditioningInfo( + embeds=text_embedding, + extra_conditioning=None, + ) + def _apply_standard_conditioning( self, x, @@ -324,8 +374,9 @@ class InvokeAIDiffuserComponent: x_twice = torch.cat([x] * 2) sigma_twice = torch.cat([sigma] * 2) - assert len(conditioning_data.text_embeddings) == 1 - text_embeddings = conditioning_data.text_embeddings[0] + text_embeddings = self._prepare_text_embeddings(conditioning_data.text_embeddings) + if len(conditioning_data.text_embeddings) > 1: + cross_attention_kwargs = {"regional_prompt_data": None} cross_attention_kwargs = None if conditioning_data.ip_adapter_conditioning is not None: