From 54971afe448996e5c7bfa5b5ead5b28c00edbc6e Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Tue, 27 Feb 2024 20:05:02 -0500
Subject: [PATCH] Add symmetric support for regional negative text prompts.

---
 invokeai/app/invocations/latent.py          | 61 ++++++++++++-------
 .../stable_diffusion/diffusers_pipeline.py  |  5 +-
 .../diffusion/conditioning_data.py          |  8 +--
 .../diffusion/shared_invokeai_diffusion.py  | 52 ++++++----------
 4 files changed, 65 insertions(+), 61 deletions(-)

diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index 7716457843..dec2bcf7f7 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -44,6 +44,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
     ConditioningData,
     IPAdapterConditioningInfo,
+    SDXLConditioningInfo,
 )
 
 from ...backend.model_management.lora import ModelPatcher
@@ -233,8 +234,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
     positive_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
         description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
     )
-    negative_conditioning: ConditioningField = InputField(
-        description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
+    negative_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
+        description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
     )
     noise: Optional[LatentsField] = InputField(
         default=None,
@@ -327,6 +328,31 @@ class DenoiseLatentsInvocation(BaseInvocation):
             base_model=base_model,
         )
 
+    def _get_text_embeddings_and_masks(
+        self,
+        cond_field: Union[ConditioningField, list[ConditioningField]],
+        context: InvocationContext,
+        device: torch.device,
+        dtype: torch.dtype,
+    ) -> tuple[Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]]]:
+        # Normalize cond_field to a list.
+        cond_list = cond_field
+        if not isinstance(cond_list, list):
+            cond_list = [cond_list]
+
+        text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
+        text_embeddings_masks: list[Optional[torch.Tensor]] = []
+        for cond in cond_list:
+            cond_data = context.services.latents.get(cond.conditioning_name)
+            text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
+
+            mask = cond.mask
+            if mask is not None:
+                mask = context.services.latents.get(mask.mask_name)
+            text_embeddings_masks.append(mask)
+
+        return text_embeddings, text_embeddings_masks
+
     def get_conditioning_data(
         self,
         context: InvocationContext,
@@ -334,29 +360,18 @@ class DenoiseLatentsInvocation(BaseInvocation):
         unet,
         seed,
     ) -> ConditioningData:
-        # self.positive_conditioning could be a list or a single ConditioningField. Normalize to a list here.
-        positive_conditioning_list = self.positive_conditioning
-        if not isinstance(positive_conditioning_list, list):
-            positive_conditioning_list = [positive_conditioning_list]
-
-        text_embeddings: list[BasicConditioningInfo] = []
-        text_embeddings_masks: list[Optional[torch.Tensor]] = []
-        for positive_conditioning in positive_conditioning_list:
-            positive_cond_data = context.services.latents.get(positive_conditioning.conditioning_name)
-            text_embeddings.append(positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype))
-
-            mask = positive_conditioning.mask
-            if mask is not None:
-                mask = context.services.latents.get(mask.mask_name)
-            text_embeddings_masks.append(mask)
-
-        negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
-        uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
+        cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
+            self.positive_conditioning, context, unet.device, unet.dtype
+        )
+        uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
+            self.negative_conditioning, context, unet.device, unet.dtype
+        )
 
         conditioning_data = ConditioningData(
-            unconditioned_embeddings=uc,
-            text_embeddings=text_embeddings,
-            text_embedding_masks=text_embeddings_masks,
+            uncond_text_embeddings=uncond_text_embeddings,
+            uncond_text_embedding_masks=uncond_text_embedding_masks,
+            cond_text_embeddings=cond_text_embeddings,
+            cond_text_embedding_masks=cond_text_embedding_masks,
             guidance_scale=self.cfg_scale,
             guidance_rescale_multiplier=self.cfg_rescale_multiplier,
             postprocessing_settings=PostprocessingSettings(
diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index febd9ad792..ca1406ed4a 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -404,12 +404,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if timesteps.shape[0] == 0:
             return latents
 
-        extra_conditioning_info = conditioning_data.text_embeddings[0].extra_conditioning
+        extra_conditioning_info = conditioning_data.cond_text_embeddings[0].extra_conditioning
         use_cross_attention_control = (
            extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
         )
         use_ip_adapter = ip_adapter_data is not None
-        use_regional_prompting = len(conditioning_data.text_embeddings) > 1
+        # HACK(ryand): Fix this logic.
+        use_regional_prompting = len(conditioning_data.cond_text_embeddings) > 1
         if sum([use_cross_attention_control, use_ip_adapter, use_regional_prompting]) > 1:
             raise Exception(
                 "Cross-attention control, IP-Adapter, and regional prompting cannot be used simultaneously (yet)."
diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
index 83bb78e42e..9c7609305e 100644
--- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
+++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
@@ -65,10 +65,10 @@ class IPAdapterConditioningInfo:
 
 
 @dataclass
 class ConditioningData:
-    # TODO(ryand): Support masks for unconditioned_embeddings.
-    unconditioned_embeddings: BasicConditioningInfo
-    text_embeddings: list[BasicConditioningInfo]
-    text_embedding_masks: list[Optional[torch.Tensor]]
+    uncond_text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]]
+    uncond_text_embedding_masks: list[Optional[torch.Tensor]]
+    cond_text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]]
+    cond_text_embedding_masks: list[Optional[torch.Tensor]]
     """
     Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
index 940eafe69c..f22cf1375e 100644
--- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
+++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
@@ -234,7 +234,8 @@ class InvokeAIDiffuserComponent:
         down_block_res_samples, mid_block_res_sample = None, None
         # HACK(ryan): Currently, we just take the first text embedding if there's more than one. We should probably
         # concatenate all of the embeddings for the ControlNet, but not apply embedding masks.
-        text_embeddings = conditioning_data.text_embeddings[0]
+        uncond_text_embeddings = conditioning_data.uncond_text_embeddings[0]
+        cond_text_embeddings = conditioning_data.cond_text_embeddings[0]
 
         # control_data should be type List[ControlNetData]
         # this loop covers both ControlNet (one ControlNetData in list)
@@ -265,38 +266,25 @@ class InvokeAIDiffuserComponent:
                 added_cond_kwargs = None
 
                 if cfg_injection:  # only applying ControlNet to conditional instead of in unconditioned
-                    if type(text_embeddings) is SDXLConditioningInfo:
+                    if type(cond_text_embeddings) is SDXLConditioningInfo:
                         added_cond_kwargs = {
-                            "text_embeds": text_embeddings.pooled_embeds,
-                            "time_ids": text_embeddings.add_time_ids,
+                            "text_embeds": cond_text_embeddings.pooled_embeds,
+                            "time_ids": cond_text_embeddings.add_time_ids,
                         }
-                    encoder_hidden_states = text_embeddings.embeds
+                    encoder_hidden_states = cond_text_embeddings.embeds
                     encoder_attention_mask = None
                 else:
-                    if type(text_embeddings) is SDXLConditioningInfo:
+                    if type(cond_text_embeddings) is SDXLConditioningInfo:
                         added_cond_kwargs = {
                             "text_embeds": torch.cat(
-                                [
-                                    # TODO: how to pad? just by zeros? or even truncate?
-                                    conditioning_data.unconditioned_embeddings.pooled_embeds,
-                                    text_embeddings.pooled_embeds,
-                                ],
-                                dim=0,
+                                [uncond_text_embeddings.pooled_embeds, cond_text_embeddings.pooled_embeds], dim=0
                             ),
                             "time_ids": torch.cat(
-                                [
-                                    conditioning_data.unconditioned_embeddings.add_time_ids,
-                                    text_embeddings.add_time_ids,
-                                ],
-                                dim=0,
+                                [uncond_text_embeddings.add_time_ids, cond_text_embeddings.add_time_ids], dim=0
                             ),
                         }
-                    (
-                        encoder_hidden_states,
-                        encoder_attention_mask,
-                    ) = self._concat_conditionings_for_batch(
-                        conditioning_data.unconditioned_embeddings.embeds,
-                        text_embeddings.embeds,
+                    (encoder_hidden_states, encoder_attention_mask) = self._concat_conditionings_for_batch(
+                        uncond_text_embeddings.embeds, cond_text_embeddings.embeds
                     )
                 if isinstance(control_datum.weight, list):
                     # if controlnet has multiple weights, use the weight for the current step
@@ -487,14 +475,14 @@ class InvokeAIDiffuserComponent:
             cross_attention_kwargs = None
         _, _, h, w = x.shape
         cond_text = RegionalTextConditioningInfo.from_text_conditioning_and_masks(
-            text_conditionings=conditioning_data.text_embeddings,
-            masks=conditioning_data.text_embedding_masks,
+            text_conditionings=conditioning_data.cond_text_embeddings,
+            masks=conditioning_data.cond_text_embedding_masks,
             latent_height=h,
             latent_width=w,
         )
         uncond_text = RegionalTextConditioningInfo.from_text_conditioning_and_masks(
-            text_conditionings=[conditioning_data.unconditioned_embeddings],
-            masks=[None],
+            text_conditionings=conditioning_data.uncond_text_embeddings,
+            masks=conditioning_data.uncond_text_embedding_masks,
             latent_height=h,
             latent_width=w,
         )
@@ -579,8 +567,8 @@ class InvokeAIDiffuserComponent:
         slower execution speed.
         """
 
-        assert len(conditioning_data.text_embeddings) == 1
-        text_embeddings = conditioning_data.text_embeddings[0]
+        assert len(conditioning_data.cond_text_embeddings) == 1
+        text_embeddings = conditioning_data.cond_text_embeddings[0]
 
         # Since we are running the conditioned and unconditioned passes sequentially, we need to split the ControlNet
         # and T2I-Adapter residuals into two chunks.
@@ -642,15 +630,15 @@ class InvokeAIDiffuserComponent:
         is_sdxl = type(text_embeddings) is SDXLConditioningInfo
         if is_sdxl:
             added_cond_kwargs = {
-                "text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
-                "time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
+                "text_embeds": conditioning_data.uncond_text_embeddings[0].pooled_embeds,
+                "time_ids": conditioning_data.uncond_text_embeddings[0].add_time_ids,
             }
 
         # Run unconditioned UNet denoising (i.e. negative prompt).
         unconditioned_next_x = self.model_forward_callback(
             x,
             sigma,
-            conditioning_data.unconditioned_embeddings.embeds,
+            conditioning_data.uncond_text_embeddings[0].embeds,
             cross_attention_kwargs=cross_attention_kwargs,
             down_block_additional_residuals=uncond_down_block,
             mid_block_additional_residual=uncond_mid_block,
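
Note (illustration, not part of the patch): the core of this change is that positive and negative conditioning now flow through the same normalize-to-list path, so each side can carry multiple regional prompts with optional masks. The following is a minimal, self-contained sketch of that symmetric normalization. All names in it are hypothetical stand-ins: TextConditioning plays the role of BasicConditioningInfo, and to_embedding_lists mirrors what _get_text_embeddings_and_masks does without the InvokeAI service layer.

from dataclasses import dataclass
from typing import Optional, Union

import torch


@dataclass
class TextConditioning:
    # Stand-in for BasicConditioningInfo: one prompt's embedding plus an
    # optional regional mask over the latent grid.
    embeds: torch.Tensor
    mask: Optional[torch.Tensor] = None


def to_embedding_lists(
    cond: Union[TextConditioning, list[TextConditioning]],
    device: torch.device,
    dtype: torch.dtype,
) -> tuple[list[torch.Tensor], list[Optional[torch.Tensor]]]:
    # Accept a single conditioning or a list (as the patched helper does) and
    # return parallel lists of embeddings and (possibly None) masks.
    cond_list = cond if isinstance(cond, list) else [cond]
    embeddings = [c.embeds.to(device=device, dtype=dtype) for c in cond_list]
    masks = [c.mask for c in cond_list]
    return embeddings, masks


if __name__ == "__main__":
    device, dtype = torch.device("cpu"), torch.float32
    # Two regional positive prompts, one global (unmasked) negative prompt.
    positive = [
        TextConditioning(torch.randn(1, 77, 768), mask=torch.ones(1, 64, 64)),
        TextConditioning(torch.randn(1, 77, 768), mask=torch.zeros(1, 64, 64)),
    ]
    negative = TextConditioning(torch.randn(1, 77, 768))

    cond_embeds, cond_masks = to_embedding_lists(positive, device, dtype)
    uncond_embeds, uncond_masks = to_embedding_lists(negative, device, dtype)

    # Both sides now have the same list-of-embeddings/list-of-masks shape,
    # which is what lets the diffusion loop treat regional negative prompts
    # symmetrically with regional positive prompts.
    assert len(cond_embeds) == 2 and len(uncond_embeds) == 1
    assert uncond_masks[0] is None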