From f590b39f886fcb39342ab5a7efd7fb212deb31fb Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Thu, 15 Feb 2024 14:41:54 -0500
Subject: [PATCH] Add support for a list of ConditioningFields in DenoiseLatents.

---
 invokeai/app/invocations/latent.py             | 19 ++++++--
 .../stable_diffusion/diffusers_pipeline.py     | 18 ++++++--
 .../diffusion/conditioning_data.py             |  6 +--
 .../diffusion/shared_invokeai_diffusion.py     | 44 ++++++++++++-------
 4 files changed, 58 insertions(+), 29 deletions(-)

diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index 8769e65652..4b328f2af7 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -40,7 +40,11 @@ from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
 from invokeai.backend.model_management.models import ModelType, SilenceWarnings
-from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
+    BasicConditioningInfo,
+    ConditioningData,
+    IPAdapterConditioningInfo,
+)
 
 from ...backend.model_management.lora import ModelPatcher
 from ...backend.model_management.models import BaseModelType
@@ -330,15 +334,22 @@ class DenoiseLatentsInvocation(BaseInvocation):
         unet,
         seed,
     ) -> ConditioningData:
-        positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
-        c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
+        # self.positive_conditioning could be a list or a single ConditioningField. Normalize to a list here.
+        positive_conditioning_list = self.positive_conditioning
+        if not isinstance(positive_conditioning_list, list):
+            positive_conditioning_list = [positive_conditioning_list]
+
+        text_embeddings: list[BasicConditioningInfo] = []
+        for positive_conditioning in positive_conditioning_list:
+            positive_cond_data = context.services.latents.get(positive_conditioning.conditioning_name)
+            text_embeddings.append(positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype))
 
         negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
         uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
 
         conditioning_data = ConditioningData(
             unconditioned_embeddings=uc,
-            text_embeddings=c,
+            text_embeddings=text_embeddings,
             guidance_scale=self.cfg_scale,
             guidance_rescale_multiplier=self.cfg_rescale_multiplier,
             postprocessing_settings=PostprocessingSettings(
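
Note on the latent.py change: `positive_conditioning` may now be either a single
ConditioningField or a list of them, and get_conditioning_data() normalizes both
cases to a list before loading embeddings, so downstream code only ever sees a
`list[BasicConditioningInfo]`. A minimal, self-contained sketch of the same
normalize-to-list pattern (the helper name and sample values are illustrative,
not part of this patch):

    def normalize_to_list(value):
        # Wrap a bare value in a single-element list; pass lists through as-is.
        if not isinstance(value, list):
            return [value]
        return value

    # Both call styles yield the same shape of input for the loop that follows:
    assert normalize_to_list("cond_a") == ["cond_a"]
    assert normalize_to_list(["cond_a", "cond_b"]) == ["cond_a", "cond_b"]
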
diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index 0bd30f4e55..1b725e58ee 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -419,21 +419,33 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if timesteps.shape[0] == 0:
             return latents, attention_map_saver
 
+        extra_conditioning_info = conditioning_data.text_embeddings[0].extra_conditioning
+        use_cross_attention_control = (
+            extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
+        )
+        use_ip_adapter = ip_adapter_data is not None
+        use_regional_prompting = len(conditioning_data.text_embeddings) > 1
+        if sum([use_cross_attention_control, use_ip_adapter, use_regional_prompting]) > 1:
+            raise Exception(
+                "Cross-attention control, IP-Adapter, and regional prompting cannot be used simultaneously (yet)."
+            )
+
         ip_adapter_unet_patcher = None
-        extra_conditioning_info = conditioning_data.text_embeddings.extra_conditioning
-        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
+        if use_cross_attention_control:
             attn_ctx = self.invokeai_diffuser.custom_attention_context(
                 self.invokeai_diffuser.model,
                 extra_conditioning_info=extra_conditioning_info,
                 step_count=len(self.scheduler.timesteps),
             )
             self.use_ip_adapter = False
-        elif ip_adapter_data is not None:
+        elif use_ip_adapter:
             # TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
             # As it is now, the IP-Adapter will silently be skipped.
             ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
             attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
             self.use_ip_adapter = True
+        elif use_regional_prompting:
+            raise NotImplementedError("Regional prompting is not yet supported.")
         else:
             attn_ctx = nullcontext()
 
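
Note on the diffusers_pipeline.py change: the new guard computes one boolean per
feature and rejects any combination of two or more with a single check, relying
on sum() counting True as 1 and False as 0. A standalone sketch of the idiom
(the helper and flag names are illustrative, not part of this patch):

    def check_mutually_exclusive(**features: bool) -> None:
        # Raise if more than one mutually-exclusive feature is enabled.
        enabled = [name for name, on in features.items() if on]
        if len(enabled) > 1:
            raise ValueError(f"Cannot combine: {', '.join(enabled)}")

    check_mutually_exclusive(cross_attn_control=True, ip_adapter=False, regional=False)  # ok
    # check_mutually_exclusive(cross_attn_control=True, ip_adapter=True, regional=False)  # raises
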
diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
index 6b66d94f4d..2fa66632b4 100644
--- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
+++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
@@ -62,7 +62,7 @@ class IPAdapterConditioningInfo:
 @dataclass
 class ConditioningData:
     unconditioned_embeddings: BasicConditioningInfo
-    text_embeddings: BasicConditioningInfo
+    text_embeddings: list[BasicConditioningInfo]
     """
     Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
    `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
@@ -82,10 +82,6 @@ class ConditioningData:
 
     ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]] = None
 
-    @property
-    def dtype(self):
-        return self.text_embeddings.dtype
-
     def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
         scheduler_args = dict(self.scheduler_args)
         step_method = inspect.signature(scheduler.step)
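
Note on the conditioning_data.py change: with text_embeddings now a list, the
old ConditioningData.dtype property no longer has a single obvious source, so it
is removed rather than arbitrarily redefined. Callers that still need a dtype
can read it from an entry, assuming all entries were already moved to the same
device and dtype upstream (as get_conditioning_data() does). A hedged sketch of
that caller-side pattern (the stand-in class is illustrative, not part of this
patch):

    import torch
    from dataclasses import dataclass

    @dataclass
    class StubConditioningInfo:  # stand-in for BasicConditioningInfo
        embeds: torch.Tensor

    text_embeddings = [StubConditioningInfo(torch.zeros(1, 77, 768, dtype=torch.float16))]
    # Where the removed ConditioningData.dtype was used, the first entry's
    # tensor dtype serves the same purpose:
    dtype = text_embeddings[0].embeds.dtype
    assert dtype == torch.float16
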
diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
index 2b72c808e4..e77be6e1f1 100644
--- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
+++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
@@ -116,9 +116,12 @@ class InvokeAIDiffuserComponent:
         timestep: torch.Tensor,
         step_index: int,
         total_step_count: int,
-        conditioning_data,
+        conditioning_data: ConditioningData,
     ):
         down_block_res_samples, mid_block_res_sample = None, None
+        # HACK(ryan): Currently, we just take the first text embedding if there's more than one. We should probably
+        # concatenate all of the embeddings for the ControlNet, but not apply embedding masks.
+        text_embeddings = conditioning_data.text_embeddings[0]
 
         # control_data should be type List[ControlNetData]
         # this loop covers both ControlNet (one ControlNetData in list)
@@ -149,28 +152,28 @@ class InvokeAIDiffuserComponent:
             added_cond_kwargs = None
 
             if cfg_injection:  # only applying ControlNet to conditional instead of in unconditioned
-                if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
+                if type(text_embeddings) is SDXLConditioningInfo:
                     added_cond_kwargs = {
-                        "text_embeds": conditioning_data.text_embeddings.pooled_embeds,
-                        "time_ids": conditioning_data.text_embeddings.add_time_ids,
+                        "text_embeds": text_embeddings.pooled_embeds,
+                        "time_ids": text_embeddings.add_time_ids,
                     }
-                encoder_hidden_states = conditioning_data.text_embeddings.embeds
+                encoder_hidden_states = text_embeddings.embeds
                 encoder_attention_mask = None
             else:
-                if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
+                if type(text_embeddings) is SDXLConditioningInfo:
                     added_cond_kwargs = {
                         "text_embeds": torch.cat(
                             [
                                 # TODO: how to pad? just by zeros? or even truncate?
                                 conditioning_data.unconditioned_embeddings.pooled_embeds,
-                                conditioning_data.text_embeddings.pooled_embeds,
+                                text_embeddings.pooled_embeds,
                             ],
                             dim=0,
                         ),
                         "time_ids": torch.cat(
                             [
                                 conditioning_data.unconditioned_embeddings.add_time_ids,
-                                conditioning_data.text_embeddings.add_time_ids,
+                                text_embeddings.add_time_ids,
                             ],
                             dim=0,
                         ),
@@ -180,7 +183,7 @@ class InvokeAIDiffuserComponent:
                     encoder_attention_mask,
                 ) = self._concat_conditionings_for_batch(
                     conditioning_data.unconditioned_embeddings.embeds,
-                    conditioning_data.text_embeddings.embeds,
+                    text_embeddings.embeds,
                 )
             if isinstance(control_datum.weight, list):
                 # if controlnet has multiple weights, use the weight for the current step
@@ -346,6 +349,9 @@ class InvokeAIDiffuserComponent:
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)
 
+        assert len(conditioning_data.text_embeddings) == 1
+        text_embeddings = conditioning_data.text_embeddings[0]
+
         cross_attention_kwargs = None
         if conditioning_data.ip_adapter_conditioning is not None:
             # Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
@@ -359,27 +365,27 @@ class InvokeAIDiffuserComponent:
             }
 
         added_cond_kwargs = None
-        if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
+        if type(text_embeddings) is SDXLConditioningInfo:
             added_cond_kwargs = {
                 "text_embeds": torch.cat(
                     [
                         # TODO: how to pad? just by zeros? or even truncate?
                         conditioning_data.unconditioned_embeddings.pooled_embeds,
-                        conditioning_data.text_embeddings.pooled_embeds,
+                        text_embeddings.pooled_embeds,
                     ],
                     dim=0,
                 ),
                 "time_ids": torch.cat(
                     [
                         conditioning_data.unconditioned_embeddings.add_time_ids,
-                        conditioning_data.text_embeddings.add_time_ids,
+                        text_embeddings.add_time_ids,
                     ],
                     dim=0,
                 ),
             }
 
         both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
-            conditioning_data.unconditioned_embeddings.embeds, conditioning_data.text_embeddings.embeds
+            conditioning_data.unconditioned_embeddings.embeds, text_embeddings.embeds
         )
         both_results = self.model_forward_callback(
             x_twice,
@@ -408,6 +414,10 @@ class InvokeAIDiffuserComponent:
         """Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of
         slower execution speed.
         """
+
+        assert len(conditioning_data.text_embeddings) == 1
+        text_embeddings = conditioning_data.text_embeddings[0]
+
         # Since we are running the conditioned and unconditioned passes sequentially, we need to split the ControlNet
         # and T2I-Adapter residuals into two chunks.
         uncond_down_block, cond_down_block = None, None
@@ -465,7 +475,7 @@ class InvokeAIDiffuserComponent:
 
         # Prepare SDXL conditioning kwargs for the unconditioned pass.
         added_cond_kwargs = None
-        is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
+        is_sdxl = type(text_embeddings) is SDXLConditioningInfo
         if is_sdxl:
             added_cond_kwargs = {
                 "text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
@@ -509,15 +519,15 @@ class InvokeAIDiffuserComponent:
         added_cond_kwargs = None
         if is_sdxl:
             added_cond_kwargs = {
-                "text_embeds": conditioning_data.text_embeddings.pooled_embeds,
-                "time_ids": conditioning_data.text_embeddings.add_time_ids,
+                "text_embeds": text_embeddings.pooled_embeds,
+                "time_ids": text_embeddings.add_time_ids,
             }
 
         # Run conditioned UNet denoising (i.e. positive prompt).
         conditioned_next_x = self.model_forward_callback(
             x,
             sigma,
-            conditioning_data.text_embeddings.embeds,
+            text_embeddings.embeds,
             cross_attention_kwargs=cross_attention_kwargs,
             down_block_additional_residuals=cond_down_block,
             mid_block_additional_residual=cond_mid_block,
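
Note on the shared_invokeai_diffusion.py changes: until regional prompting is
implemented, every denoising path still supports exactly one positive embedding.
The `assert len(conditioning_data.text_embeddings) == 1` statements make that
contract explicit instead of silently dropping extra entries, while
do_controlnet_step() documents its weaker take-the-first behavior with a HACK
comment. A standalone sketch of the fail-loudly contract (names are
illustrative, not part of this patch):

    def single(items: list):
        # Return the only element, failing loudly on any other length.
        assert len(items) == 1, f"expected exactly one element, got {len(items)}"
        return items[0]

    text_embeddings = single(["embedding_0"])  # ok
    # single(["embedding_0", "embedding_1"])   # AssertionError until multi-prompt support lands
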