From c059bc31628811be5d448de439a74587bef44544 Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Fri, 8 Mar 2024 12:57:33 -0500
Subject: [PATCH] Add TextConditioningRegions to the TextConditioningData data
 structure.

---
 invokeai/app/invocations/latent.py                 |  6 +-
 .../stable_diffusion/diffusers_pipeline.py         |  2 +-
 .../diffusion/conditioning_data.py                 | 58 +++++++++++++++----
 .../diffusion/shared_invokeai_diffusion.py         | 54 +++++++++--------
 4 files changed, 79 insertions(+), 41 deletions(-)

diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index be5ed91581..0d894dcee4 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -381,8 +381,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
         uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
 
         conditioning_data = TextConditioningData(
-            unconditioned_embeddings=uc,
-            text_embeddings=c,
+            uncond_text=uc,
+            cond_text=c,
+            uncond_regions=None,
+            cond_regions=None,
             guidance_scale=self.cfg_scale,
             guidance_rescale_multiplier=self.cfg_rescale_multiplier,
         )
diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index 190cc9869f..7ef93b0bcb 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -405,7 +405,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             return latents
 
         ip_adapter_unet_patcher = None
-        extra_conditioning_info = conditioning_data.text_embeddings.extra_conditioning
+        extra_conditioning_info = conditioning_data.cond_text.extra_conditioning
         if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
             attn_ctx = self.invokeai_diffuser.custom_attention_context(
                 self.invokeai_diffuser.model,
diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
index 051b2fed1f..6ef6d68fca 100644
--- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
+++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
@@ -63,14 +63,52 @@ class IPAdapterConditioningInfo:
 
 
 @dataclass
+class Range:
+    start: int
+    end: int
+
+
+class TextConditioningRegions:
+    def __init__(
+        self,
+        masks: torch.Tensor,
+        ranges: list[Range],
+    ):
+        # A binary mask indicating the regions of the image that each prompt should be applied to.
+        # Shape: (1, num_prompts, height, width)
+        # Dtype: torch.bool
+        self.masks = masks
+
+        # A list of ranges indicating the start and end indices of the embeddings that the corresponding mask applies
+        # to. ranges[i] contains the embedding range for the i'th prompt / mask.
+        self.ranges = ranges
+
+        assert self.masks.shape[1] == len(self.ranges)
+
+
 class TextConditioningData:
-    unconditioned_embeddings: BasicConditioningInfo
-    text_embeddings: BasicConditioningInfo
-    # Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
-    # `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
-    # Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
-    # images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
-    guidance_scale: Union[float, List[float]]
-    # For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7.
-    # See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
-    guidance_rescale_multiplier: float = 0
+    def __init__(
+        self,
+        uncond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
+        cond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
+        uncond_regions: Optional[TextConditioningRegions],
+        cond_regions: Optional[TextConditioningRegions],
+        guidance_scale: Union[float, List[float]],
+        guidance_rescale_multiplier: float = 0,
+    ):
+        self.uncond_text = uncond_text
+        self.cond_text = cond_text
+        self.uncond_regions = uncond_regions
+        self.cond_regions = cond_regions
+        # Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+        # `guidance_scale` is defined as `w` in equation 2 of the [Imagen paper](https://arxiv.org/pdf/2205.11487.pdf).
+        # Guidance scale is enabled by setting `guidance_scale > 1`. A higher guidance scale encourages the model to
+        # generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
+        self.guidance_scale = guidance_scale
+        # For models trained using zero-terminal SNR ("ztsnr"), a guidance_rescale_multiplier of 0.7 is suggested.
+        # See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+        self.guidance_rescale_multiplier = guidance_rescale_multiplier
+
+    def is_sdxl(self):
+        assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
+        return isinstance(self.cond_text, SDXLConditioningInfo)
diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
index 5108521982..46150d2621 100644
--- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
+++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
@@ -12,7 +12,6 @@ from invokeai.app.services.config.config_default import get_config
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     ExtraConditioningInfo,
     IPAdapterConditioningInfo,
-    SDXLConditioningInfo,
     TextConditioningData,
 )
 
@@ -91,7 +90,7 @@ class InvokeAIDiffuserComponent:
         timestep: torch.Tensor,
         step_index: int,
         total_step_count: int,
-        conditioning_data,
+        conditioning_data: TextConditioningData,
     ):
         down_block_res_samples, mid_block_res_sample = None, None
 
@@ -124,28 +123,28 @@ class InvokeAIDiffuserComponent:
                 added_cond_kwargs = None
                 if cfg_injection:  # only applying ControlNet to conditional instead of in unconditioned
-                    if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
+                    if conditioning_data.is_sdxl():
                         added_cond_kwargs = {
-                            "text_embeds": conditioning_data.text_embeddings.pooled_embeds,
-                            "time_ids": conditioning_data.text_embeddings.add_time_ids,
+                            "text_embeds": conditioning_data.cond_text.pooled_embeds,
+                            "time_ids": conditioning_data.cond_text.add_time_ids,
                         }
-                    encoder_hidden_states = conditioning_data.text_embeddings.embeds
+                    encoder_hidden_states = conditioning_data.cond_text.embeds
                     encoder_attention_mask = None
                 else:
-                    if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
+                    if conditioning_data.is_sdxl():
                         added_cond_kwargs = {
                             "text_embeds": torch.cat(
                                 [
                                     # TODO: how to pad? just by zeros? or even truncate?
-                                    conditioning_data.unconditioned_embeddings.pooled_embeds,
-                                    conditioning_data.text_embeddings.pooled_embeds,
+                                    conditioning_data.uncond_text.pooled_embeds,
+                                    conditioning_data.cond_text.pooled_embeds,
                                 ],
                                 dim=0,
                             ),
                             "time_ids": torch.cat(
                                 [
-                                    conditioning_data.unconditioned_embeddings.add_time_ids,
-                                    conditioning_data.text_embeddings.add_time_ids,
+                                    conditioning_data.uncond_text.add_time_ids,
+                                    conditioning_data.cond_text.add_time_ids,
                                 ],
                                 dim=0,
                             ),
                         }
                     (
                         encoder_hidden_states,
                         encoder_attention_mask,
                     ) = self._concat_conditionings_for_batch(
-                        conditioning_data.unconditioned_embeddings.embeds,
-                        conditioning_data.text_embeddings.embeds,
+                        conditioning_data.uncond_text.embeds,
+                        conditioning_data.cond_text.embeds,
                     )
                 if isinstance(control_datum.weight, list):
                     # if controlnet has multiple weights, use the weight for the current step
@@ -325,27 +324,27 @@ class InvokeAIDiffuserComponent:
         }
 
         added_cond_kwargs = None
-        if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
+        if conditioning_data.is_sdxl():
             added_cond_kwargs = {
                 "text_embeds": torch.cat(
                     [
                         # TODO: how to pad? just by zeros? or even truncate?
-                        conditioning_data.unconditioned_embeddings.pooled_embeds,
-                        conditioning_data.text_embeddings.pooled_embeds,
+                        conditioning_data.uncond_text.pooled_embeds,
+                        conditioning_data.cond_text.pooled_embeds,
                     ],
                     dim=0,
                 ),
                 "time_ids": torch.cat(
                     [
-                        conditioning_data.unconditioned_embeddings.add_time_ids,
-                        conditioning_data.text_embeddings.add_time_ids,
+                        conditioning_data.uncond_text.add_time_ids,
+                        conditioning_data.cond_text.add_time_ids,
                     ],
                     dim=0,
                 ),
             }
 
         both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
-            conditioning_data.unconditioned_embeddings.embeds, conditioning_data.text_embeddings.embeds
+            conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds
         )
         both_results = self.model_forward_callback(
             x_twice,
@@ -432,18 +431,17 @@ class InvokeAIDiffuserComponent:
         # Prepare SDXL conditioning kwargs for the unconditioned pass.
         added_cond_kwargs = None
-        is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
-        if is_sdxl:
+        if conditioning_data.is_sdxl():
             added_cond_kwargs = {
-                "text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
-                "time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
+                "text_embeds": conditioning_data.uncond_text.pooled_embeds,
+                "time_ids": conditioning_data.uncond_text.add_time_ids,
             }
 
         # Run unconditioned UNet denoising (i.e. negative prompt).
         unconditioned_next_x = self.model_forward_callback(
             x,
             sigma,
-            conditioning_data.unconditioned_embeddings.embeds,
+            conditioning_data.uncond_text.embeds,
             cross_attention_kwargs=cross_attention_kwargs,
             down_block_additional_residuals=uncond_down_block,
             mid_block_additional_residual=uncond_mid_block,
@@ -474,17 +472,17 @@ class InvokeAIDiffuserComponent:
         # Prepare SDXL conditioning kwargs for the conditioned pass.
         added_cond_kwargs = None
-        if is_sdxl:
+        if conditioning_data.is_sdxl():
             added_cond_kwargs = {
-                "text_embeds": conditioning_data.text_embeddings.pooled_embeds,
-                "time_ids": conditioning_data.text_embeddings.add_time_ids,
+                "text_embeds": conditioning_data.cond_text.pooled_embeds,
+                "time_ids": conditioning_data.cond_text.add_time_ids,
             }
 
         # Run conditioned UNet denoising (i.e. positive prompt).
         conditioned_next_x = self.model_forward_callback(
             x,
             sigma,
-            conditioning_data.text_embeddings.embeds,
+            conditioning_data.cond_text.embeds,
             cross_attention_kwargs=cross_attention_kwargs,
             down_block_additional_residuals=cond_down_block,
             mid_block_additional_residual=cond_mid_block,
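
For illustration, a minimal sketch of how the structures added by this patch compose. It is a sketch, not part of the commit: the tensor shapes, the 77-token ranges, and the BasicConditioningInfo constructor call (inferred from how its embeds and extra_conditioning fields are accessed in the diffs above) are all assumptions.

    import torch

    from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
        BasicConditioningInfo,
        Range,
        TextConditioningData,
        TextConditioningRegions,
    )

    # Stand-in embeddings; real ones come from the CLIP text encoder. The
    # BasicConditioningInfo(embeds=..., extra_conditioning=...) call is an
    # assumed signature based on the fields accessed in this patch.
    uncond_info = BasicConditioningInfo(embeds=torch.zeros(1, 77, 768), extra_conditioning=None)
    cond_info = BasicConditioningInfo(embeds=torch.zeros(1, 154, 768), extra_conditioning=None)

    # Two regional prompts over a hypothetical 64x64 mask: prompt 0 covers the
    # left half of the image, prompt 1 the right half.
    # masks shape: (1, num_prompts, height, width), dtype: torch.bool.
    masks = torch.zeros((1, 2, 64, 64), dtype=torch.bool)
    masks[0, 0, :, :32] = True
    masks[0, 1, :, 32:] = True

    # ranges[i] is the slice of the concatenated embedding tensor that
    # masks[:, i] applies to: prompt 0 -> tokens [0, 77), prompt 1 -> [77, 154).
    regions = TextConditioningRegions(
        masks=masks,
        ranges=[Range(start=0, end=77), Range(start=77, end=154)],
    )

    conditioning_data = TextConditioningData(
        uncond_text=uncond_info,
        cond_text=cond_info,
        uncond_regions=None,  # no regional masking on the negative prompt
        cond_regions=regions,
        guidance_scale=7.5,
        guidance_rescale_multiplier=0.0,
    )

    assert not conditioning_data.is_sdxl()  # neither side is SDXLConditioningInfo

The assert in TextConditioningRegions.__init__ is what ties the two fields together: there must be exactly one embedding Range per mask channel, so masks.shape[1] == len(ranges) always holds.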