From d183aa823cb3939700e68bd374c587dad11a341b Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Sat, 20 Apr 2024 17:09:41 -0400
Subject: [PATCH] wip

---
 invokeai/app/invocations/latent.py                 | 57 ++++++----------
 .../stable_diffusion/diffusers_pipeline.py         |  2 +-
 .../diffusion/conditioning_data.py                 | 48 +++++++++++---
 .../diffusion/custom_atttention.py                 | 36 ++++++----
 .../diffusion/regional_prompt_data.py              | 66 +++++++++++++++----
 .../diffusion/shared_invokeai_diffusion.py         | 39 +++++------
 6 files changed, 155 insertions(+), 93 deletions(-)

diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index 4534df89c1..c96bd6bb6d 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -57,10 +57,10 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
     IPAdapterConditioningInfo,
     IPAdapterData,
-    Range,
+    SDRegionalTextConditioning,
     SDXLConditioningInfo,
+    SDXLRegionalTextConditioning,
     TextConditioningData,
-    TextConditioningRegions,
 )
 from invokeai.backend.util.mask import to_standard_float_mask
 from invokeai.backend.util.silence_warnings import SilenceWarnings
@@ -408,19 +408,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
         resized_mask = tf(mask)
         return resized_mask

-    def _concat_regional_text_embeddings(
+    def _prepare_regional_text_embeddings(
         self,
         text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
-        masks: Optional[list[Optional[torch.Tensor]]],
+        masks: list[Optional[torch.Tensor]],
         latent_height: int,
         latent_width: int,
         dtype: torch.dtype,
-    ) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]:
+    ) -> Union[SDRegionalTextConditioning, SDXLRegionalTextConditioning]:
         """Concatenate regional text embeddings into a single embedding and track the region masks accordingly."""
-        if masks is None:
-            masks = [None] * len(text_conditionings)
-        assert len(text_conditionings) == len(masks)
-
         is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo

         all_masks_are_none = all(mask is None for mask in masks)
@@ -428,9 +424,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
         text_embedding = []
         pooled_embedding = None
         add_time_ids = None
-        cur_text_embedding_len = 0
         processed_masks = []
-        embedding_ranges = []

         for prompt_idx, text_embedding_info in enumerate(text_conditionings):
             mask = masks[prompt_idx]
@@ -453,32 +447,21 @@ class DenoiseLatentsInvocation(BaseInvocation):
             text_embedding.append(text_embedding_info.embeds)

             if not all_masks_are_none:
-                embedding_ranges.append(
-                    Range(
-                        start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
-                    )
-                )
                 processed_masks.append(
                     self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
                 )

-            cur_text_embedding_len += text_embedding_info.embeds.shape[1]
-
-        text_embedding = torch.cat(text_embedding, dim=1)
-        assert len(text_embedding.shape) == 3  # batch_size, seq_len, token_len
-
-        regions = None
-        if not all_masks_are_none:
-            regions = TextConditioningRegions(
-                masks=torch.cat(processed_masks, dim=1),
-                ranges=embedding_ranges,
-            )
-
         if is_sdxl:
-            return SDXLConditioningInfo(
-                embeds=text_embedding, pooled_embeds=pooled_embedding, add_time_ids=add_time_ids
-            ), regions
-        return BasicConditioningInfo(embeds=text_embedding), regions
+            return SDXLRegionalTextConditioning(
+                pooled_embeds=pooled_embedding,
+                add_time_ids=add_time_ids,
+                text_embeds=text_embedding,
+                masks=None if all_masks_are_none else processed_masks,
+            )
+        return SDRegionalTextConditioning(
+            text_embeds=text_embedding,
+            masks=None if all_masks_are_none else processed_masks,
+        )

     def get_conditioning_data(
         self,
@@ -502,14 +485,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
             uncond_list, context, unet.device, unet.dtype
         )

-        cond_text_embedding, cond_regions = self._concat_regional_text_embeddings(
+        cond_text = self._prepare_regional_text_embeddings(
             text_conditionings=cond_text_embeddings,
             masks=cond_text_embedding_masks,
             latent_height=latent_height,
             latent_width=latent_width,
             dtype=unet.dtype,
         )
-        uncond_text_embedding, uncond_regions = self._concat_regional_text_embeddings(
+        uncond_text = self._prepare_regional_text_embeddings(
             text_conditionings=uncond_text_embeddings,
             masks=uncond_text_embedding_masks,
             latent_height=latent_height,
@@ -518,10 +501,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
         )

         conditioning_data = TextConditioningData(
-            uncond_text=uncond_text_embedding,
-            cond_text=cond_text_embedding,
-            uncond_regions=uncond_regions,
-            cond_regions=cond_regions,
+            uncond_text=uncond_text,
+            cond_text=cond_text,
             guidance_scale=self.cfg_scale,
             guidance_rescale_multiplier=self.cfg_rescale_multiplier,
         )
diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index 8b90c815ae..013a96363d 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -386,7 +386,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

         use_ip_adapter = ip_adapter_data is not None
         use_regional_prompting = (
-            conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
+            conditioning_data.cond_text.uses_regional_prompts() or conditioning_data.uncond_text.uses_regional_prompts()
         )
         unet_attention_patcher = None
         self.use_ip_adapter = use_ip_adapter
diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
index 85950a01df..15b3329f1e 100644
--- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
+++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
@@ -95,20 +95,50 @@ class TextConditioningRegions:
         assert self.masks.shape[1] == len(self.ranges)


+class SDRegionalTextConditioning:
+    def __init__(self, text_embeds: list[torch.Tensor], masks: Optional[list[torch.Tensor]]):
+        if masks is not None:
+            assert len(text_embeds) == len(masks)
+
+        # A list of text embeddings. text_embeds[i] contains the text embeddings for the i'th prompt.
+        self.text_embeds = text_embeds
+        # A list of masks indicating the regions of the image that the prompts should be applied to. masks[i] contains
+        # the mask for the i'th prompt. Each mask has shape (1, height, width).
+        self.masks = masks
+
+    def uses_regional_prompts(self):
+        # If there is more than one prompt, we treat this as regional prompting, even if there are no masks, because
+        # the regional prompting logic is used to combine the information from multiple prompts.
+        return len(self.text_embeds) > 1 or self.masks is not None
+
+
+class SDXLRegionalTextConditioning(SDRegionalTextConditioning):
+    def __init__(
+        self,
+        pooled_embeds: torch.Tensor,
+        add_time_ids: torch.Tensor,
+        text_embeds: list[torch.Tensor],
+        masks: Optional[list[torch.Tensor]],
+    ):
+        super().__init__(text_embeds, masks)
+
+        # Pooled embeddings for the global prompt.
+        self.pooled_embeds = pooled_embeds
+        # Additional global conditioning inputs for SDXL. The name "time_ids" comes from diffusers, and is a bit of a
+        # misnomer. This Tensor contains original_size, crop_coords, and target_size conditioning.
+        self.add_time_ids = add_time_ids
+
+
 class TextConditioningData:
     def __init__(
         self,
-        uncond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
-        cond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
-        uncond_regions: Optional[TextConditioningRegions],
-        cond_regions: Optional[TextConditioningRegions],
+        uncond_text: Union[SDRegionalTextConditioning, SDXLRegionalTextConditioning],
+        cond_text: Union[SDRegionalTextConditioning, SDXLRegionalTextConditioning],
         guidance_scale: Union[float, List[float]],
         guidance_rescale_multiplier: float = 0,
     ):
         self.uncond_text = uncond_text
         self.cond_text = cond_text
-        self.uncond_regions = uncond_regions
-        self.cond_regions = cond_regions
         # Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
         # `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
         # Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
@@ -119,5 +149,7 @@ class TextConditioningData:
         self.guidance_rescale_multiplier = guidance_rescale_multiplier

     def is_sdxl(self):
-        assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
-        return isinstance(self.cond_text, SDXLConditioningInfo)
+        assert isinstance(self.uncond_text, SDXLRegionalTextConditioning) == isinstance(
+            self.cond_text, SDXLRegionalTextConditioning
+        )
+        return isinstance(self.cond_text, SDXLRegionalTextConditioning)
diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
index 1334313fe6..5ff7b9a2ca 100644
--- a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
+++ b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
@@ -63,6 +63,12 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         # If true, we are doing cross-attention, if false we are doing self-attention.
         is_cross_attention = encoder_hidden_states is not None

+        _, query_seq_len, _ = hidden_states.shape
+        if regional_prompt_data is not None and is_cross_attention:
+            assert percent_through is not None
+            prompt_masks = regional_prompt_data.get_masks(query_seq_len=query_seq_len)
+            encoder_hidden_states = regional_prompt_data.text_embeds
+
         # Start unmodified block from AttnProcessor2_0.
         # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
         residual = hidden_states
@@ -81,18 +87,26 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
         # End unmodified block from AttnProcessor2_0.

-        _, query_seq_len, _ = hidden_states.shape
-        # Handle regional prompt attention masks.
-        if regional_prompt_data is not None and is_cross_attention:
-            assert percent_through is not None
-            prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
-                query_seq_len=query_seq_len, key_seq_len=sequence_length
-            )
+        # Current:
+        # - Run attention once, with masking to control which tokens each pixel is *allowed* to pay attention to.
+        # New
+        # - Run attention on each prompt separately. (no masking)
+        # - Combine the results with a weighted sum.

-            if attention_mask is None:
-                attention_mask = prompt_region_attention_mask
-            else:
-                attention_mask = prompt_region_attention_mask + attention_mask
+        # _, query_seq_len, _ = hidden_states.shape
+        # Handle regional prompt attention masks.
+        # if regional_prompt_data is not None and is_cross_attention:
+        #     assert percent_through is not None
+        #     prompt_masks = regional_prompt_data.get_masks(query_seq_len=query_seq_len)
+
+        #     prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
+        #         query_seq_len=query_seq_len, key_seq_len=sequence_length
+        #     )
+
+        #     if attention_mask is None:
+        #         attention_mask = prompt_region_attention_mask
+        #     else:
+        #         attention_mask = prompt_region_attention_mask + attention_mask

         # Start unmodified block from AttnProcessor2_0.
         # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
diff --git a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py
index f09cc0a0d2..af396e7c2f 100644
--- a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py
+++ b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py
@@ -1,9 +1,7 @@
 import torch
 import torch.nn.functional as F

-from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
-    TextConditioningRegions,
-)
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningRegions


 class RegionalPromptData:
@@ -11,31 +9,71 @@ class RegionalPromptData:
     def __init__(
         self,
-        regions: list[TextConditioningRegions],
+        text_embeds: list[list[torch.Tensor]],
+        masks: list[list[torch.Tensor]],
         device: torch.device,
         dtype: torch.dtype,
         max_downscale_factor: int = 8,
     ):
         """Initialize a `RegionalPromptData` object.

         Args:
-            regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the
-                batch.
+            TODO(ryand): Update these docs.
+
+            text_embeds (list[list[torch.Tensor]]): The text prompt embeddings. text_embeds[b][i] contains the embedding
+                for prompt i to be applied to batch image b.
+            masks (list[list[torch.Tensor]]): The masks indicating the spatial regions of the image that each prompt
+                applies to. masks[b][i] contains the mask for text_embeds[b][i].
+
             device (torch.device): The device to use for the attention masks.
             dtype (torch.dtype): The data type to use for the attention masks.
             max_downscale_factor: Spatial masks will be prepared for downscale factors from 1 to max_downscale_factor
                 in steps of 2x.
         """
-        self._regions = regions
+
+        assert len(text_embeds) == len(masks)
+        for text_embeds_batch, masks_batch in zip(text_embeds, masks, strict=True):
+            assert len(text_embeds_batch) == len(masks_batch)
+
+        self.prompt_count_by_batch_element = [len(text_embeds_batch) for text_embeds_batch in text_embeds]
+
+        # Flatten and concat text_embeds.
+        text_embeds_flat_list: list[torch.Tensor] = []
+        for text_embeds_batch in text_embeds:
+            text_embeds_flat_list.extend(text_embeds_batch)
+        # TODO(ryand): Or stack?
+        # TODO(ryand): Text embeds might not all be the same size (if there were long prompts).
+        self.text_embeds = torch.cat(text_embeds_flat_list, dim=0)
+
+        # Flatten and concat masks.
+        masks_flat_list = []
+        for mask_batch in masks:
+            masks_flat_list.extend(mask_batch)
+        self._masks = torch.cat(masks_flat_list, dim=0)
+
         self._device = device
         self._dtype = dtype
-        # self._spatial_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query
-        # sequence length of s.
-        self._spatial_masks_by_seq_len: list[dict[int, torch.Tensor]] = self._prepare_spatial_masks(
-            regions, max_downscale_factor
-        )
-        self._negative_cross_attn_mask_score = -10000.0

-    def _prepare_spatial_masks(
+    def get_masks(self, query_seq_len: int):
+        _, h, w = self._masks.shape
+
+        # Determine the downscaling factor for the given query sequence length.
+        max_downscale_factor = 8
+        downscale_factor = 1
+        while downscale_factor <= max_downscale_factor:
+            if query_seq_len == (h // downscale_factor) * (w // downscale_factor):
+                break
+            downscale_factor *= 2
+
+        if query_seq_len != (h // downscale_factor) * (w // downscale_factor):
+            raise ValueError(f"Failed to find a mask downsampling factor for query sequence length: {query_seq_len}")
+
+        target_h = h // downscale_factor
+        target_w = w // downscale_factor
+        mask_downscaled = torch.nn.functional.interpolate(self._masks, size=(target_h, target_w), mode="nearest")
+
+        return mask_downscaled
+
+    def _prepare_spatial_masks_old(
         self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8
     ) -> list[dict[int, torch.Tensor]]:
         """Prepare the spatial masks for all downscaling factors."""
diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
index f418133e49..37e6478dd4 100644
--- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
+++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
@@ -7,12 +7,7 @@ import torch
 from typing_extensions import TypeAlias

 from invokeai.app.services.config.config_default import get_config
-from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
-    IPAdapterData,
-    Range,
-    TextConditioningData,
-    TextConditioningRegions,
-)
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
 from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData
 from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData

@@ -312,33 +307,35 @@ class InvokeAIDiffuserComponent:
                 ),
             }

-        if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
+        if conditioning_data.cond_text.uses_regional_prompts() or conditioning_data.uncond_text.uses_regional_prompts():
             # TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
             # and masks are not changing from step-to-step, so this really only needs to be done once. While this seems
             # painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
             # the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
             # awkward to handle both standard conditioning and sequential conditioning further up the stack.
-            regions = []
-            for c, r in [
-                (conditioning_data.uncond_text, conditioning_data.uncond_regions),
-                (conditioning_data.cond_text, conditioning_data.cond_regions),
-            ]:
-                if r is None:
-                    # Create a dummy mask and range for text conditioning that doesn't have region masks.
+            masks: list[list[torch.Tensor]] = []
+            for text_conditioning in [conditioning_data.uncond_text, conditioning_data.cond_text]:
+                if text_conditioning.masks is None:
+                    # Create a dummy mask for text conditioning that doesn't have region masks.
                     _, _, h, w = x.shape
-                    r = TextConditioningRegions(
-                        masks=torch.ones((1, 1, h, w), dtype=x.dtype),
-                        ranges=[Range(start=0, end=c.embeds.shape[1])],
-                    )
-                regions.append(r)
+                    masks.append([torch.ones((1, 1, h, w), dtype=x.dtype)] * len(text_conditioning.text_embeds))
+                else:
+                    masks.append(text_conditioning.masks)

             cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
-                regions=regions, device=x.device, dtype=x.dtype
+                text_embeds=[conditioning_data.uncond_text.text_embeds, conditioning_data.cond_text.text_embeds],
+                masks=masks,
+                device=x.device,
+                dtype=x.dtype,
             )
             cross_attention_kwargs["percent_through"] = step_index / total_step_count

+        # Note: We pass in the *first* text_embeds entry for both unconditioned and conditioned text embeds. This is the
+        # desired behaviour under 'normal' conditions when there is a single text prompt. In cases where we are doing
+        # regional prompting with multiple prompts, this input will be ignored altogether and the prompt information
+        # will be passed via the RegionalPromptData object.
         both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
-            conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds
+            conditioning_data.uncond_text.text_embeds[0], conditioning_data.cond_text.text_embeds[0]
         )
         both_results = self.model_forward_callback(
             x_twice,
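
The comment block added to custom_atttention.py only sketches the intended follow-up: run cross-attention once per regional prompt without masking, then blend the per-prompt outputs with a weighted sum driven by the region masks. The snippet below is a minimal illustration of that blending step and is not part of the patch; the helper name combine_per_prompt_attention, the tensor shapes, and the epsilon guard are all assumptions.

import torch


def combine_per_prompt_attention(
    per_prompt_outputs: list[torch.Tensor],  # hypothetical: one (batch, query_seq_len, dim) tensor per prompt
    per_prompt_masks: list[torch.Tensor],  # hypothetical: one (batch, query_seq_len, 1) weight tensor per prompt
) -> torch.Tensor:
    # Stack along a new leading "prompt" dimension.
    outputs = torch.stack(per_prompt_outputs, dim=0)  # (num_prompts, batch, query_seq_len, dim)
    weights = torch.stack(per_prompt_masks, dim=0)  # (num_prompts, batch, query_seq_len, 1)

    # Weighted sum over prompts, normalized so the weights at each query position sum to 1.
    # The clamp guards against query positions that no prompt mask covers.
    weight_sum = weights.sum(dim=0).clamp_min(1e-6)
    return (outputs * weights).sum(dim=0) / weight_sum

In the attention processor, the per-prompt weights would presumably come from RegionalPromptData.get_masks(query_seq_len=...) flattened over the spatial dimensions; how the uncond/cond batch halves map onto prompt_count_by_batch_element is still left open in this WIP.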
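
For reference on the new RegionalPromptData.get_masks() helper: it infers the attention level purely from the query sequence length, since at downscale factor d the UNet's spatial attention operates over (h // d) * (w // d) query positions. A quick check of that relationship, assuming a hypothetical 96x64 latent (a 768x512 pixel image; the sizes are illustrative, not taken from the patch):

h, w = 96, 64  # assumed latent size for a 768x512 image
for downscale_factor in (1, 2, 4, 8):
    print(downscale_factor, (h // downscale_factor) * (w // downscale_factor))
# prints: 1 6144, 2 1536, 4 384, 8 96 -- the query_seq_len values that get_masks() matches against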