From 93056e4ab7b3adfa7c632b4895c76499a78d67c8 Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Fri, 8 Mar 2024 13:42:35 -0500
Subject: [PATCH] Add support for lists of prompt embeddings to be passed to
 the DenoiseLatents invocation, and add handling of the conditioning region
 masks in DenoiseLatents.

---
 invokeai/app/invocations/latent.py | 213 +++++++++++++++++++++++++++--
 1 file changed, 200 insertions(+), 13 deletions(-)

diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index 0d894dcee4..f2e1822c30 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -1,5 +1,4 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
-
 import inspect
 import math
 from contextlib import ExitStack
@@ -10,6 +9,7 @@ import einops
 import numpy as np
 import numpy.typing as npt
 import torch
+import torchvision
 import torchvision.transforms as T
 from diffusers import AutoencoderKL, AutoencoderTiny
 from diffusers.configuration_utils import ConfigMixin
@@ -58,8 +58,12 @@ from invokeai.backend.model_manager import BaseModelType, LoadedModel
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
+    BasicConditioningInfo,
     IPAdapterConditioningInfo,
+    Range,
+    SDXLConditioningInfo,
     TextConditioningData,
+    TextConditioningRegions,
 )
 from invokeai.backend.util.silence_warnings import SilenceWarnings
 
@@ -288,10 +292,10 @@ def get_scheduler(
 class DenoiseLatentsInvocation(BaseInvocation):
     """Denoises noisy latents to decodable images"""
 
-    positive_conditioning: ConditioningField = InputField(
+    positive_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
         description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
     )
-    negative_conditioning: ConditioningField = InputField(
+    negative_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
         description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
     )
     noise: Optional[LatentsField] = InputField(
@@ -369,26 +373,205 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 raise ValueError("cfg_scale must be greater than 1")
         return v
 
+    def _get_text_embeddings_and_masks(
+        self,
+        cond_list: list[ConditioningField],
+        context: InvocationContext,
+        device: torch.device,
+        dtype: torch.dtype,
+    ) -> tuple[Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]]]:
+        """Get the text embeddings and masks from the input conditioning fields.
+
+        Only the first entry of each loaded conditioning is used; it is moved to `device` / `dtype`. Masks are
+        returned exactly as loaded, and a field without a mask contributes `None`.
+        """
+        text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
+        text_embeddings_masks: list[Optional[torch.Tensor]] = []
+        for cond in cond_list:
+            cond_data = context.conditioning.load(cond.conditioning_name)
+            text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
+
+            mask = cond.mask
+            if mask is not None:
+                mask = context.tensors.load(mask.mask_name)
+            text_embeddings_masks.append(mask)
+
+        return text_embeddings, text_embeddings_masks
+
+    def _preprocess_regional_prompt_mask(
+        self, mask: Optional[torch.Tensor], target_height: int, target_width: int
+    ) -> torch.Tensor:
+        """Preprocess a regional prompt mask to match the target height and width.
+        If mask is None, returns a mask of all ones with the target height and width.
+        If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation.
+
+        Returns:
+            torch.Tensor: The processed mask. shape: (1, 1, target_height, target_width). dtype: torch.bool when
+                mask is None; otherwise the resize is not followed by a cast, so the input dtype is expected to be
+                kept (callers presumably pass torch.bool masks - TODO confirm).
+        """
+        if mask is None:
+            return torch.ones((1, 1, target_height, target_width), dtype=torch.bool)
+
+        tf = torchvision.transforms.Resize(
+            (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
+        )
+        # Add a leading batch dimension before resizing.
+        mask = mask.unsqueeze(0)  # Shape: (1, h, w) -> (1, 1, h, w)
+        resized_mask = tf(mask)
+        return resized_mask
+
+    def _concat_regional_text_embeddings(
+        self,
+        text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
+        masks: Optional[list[Optional[torch.Tensor]]],
+        latent_height: int,
+        latent_width: int,
+    ) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]:
+        """Concatenate regional text embeddings into a single embedding and track the region masks accordingly.
+
+        Returns:
+            The combined conditioning info (embeds concatenated along dim 1) and the region info, which is None
+            when no prompt has a mask.
+
+        Raises:
+            ValueError: If text_conditionings is empty, if masks does not have one entry per conditioning, or if
+                cross-attention control is requested together with multiple prompts.
+        """
+        # Fail with a clear message instead of an IndexError on `text_conditionings[0]` below.
+        if not text_conditionings:
+            raise ValueError("At least one text conditioning is required.")
+        if masks is None:
+            masks = [None] * len(text_conditionings)
+        # Explicit check rather than `assert`, which is stripped when running with `python -O`.
+        if len(text_conditionings) != len(masks):
+            raise ValueError(
+                f"Expected one mask per text conditioning, got {len(text_conditionings)} conditionings and "
+                f"{len(masks)} masks."
+            )
+
+        is_sdxl = isinstance(text_conditionings[0], SDXLConditioningInfo)
+
+        all_masks_are_none = all(mask is None for mask in masks)
+
+        text_embedding = []
+        pooled_embedding = None
+        add_time_ids = None
+        cur_text_embedding_len = 0
+        processed_masks = []
+        embedding_ranges = []
+        extra_conditioning = None
+
+        for text_embedding_info, mask in zip(text_conditionings, masks):
+            if (
+                text_embedding_info.extra_conditioning is not None
+                and text_embedding_info.extra_conditioning.wants_cross_attention_control
+            ):
+                extra_conditioning = text_embedding_info.extra_conditioning
+
+            if is_sdxl:
+                # We choose a random SDXLConditioningInfo's pooled_embeds and add_time_ids here, with a preference for
+                # prompts without a mask. We prefer prompts without a mask, because they are more likely to contain
+                # global prompt information. In an ideal case, there should be exactly one global prompt without a
+                # mask, but we don't enforce this.
+
+                # HACK(ryand): The fact that we have to choose a single pooled_embedding and add_time_ids here is a
+                # fundamental interface issue. The SDXL Compel nodes are not designed to be used in the way that we use
+                # them for regional prompting. Ideally, the DenoiseLatents invocation should accept a single
+                # pooled_embeds tensor and a list of standard text embeds with region masks. This change would be a
+                # pretty major breaking change to a popular node, so for now we use this hack.
+                if pooled_embedding is None or mask is None:
+                    pooled_embedding = text_embedding_info.pooled_embeds
+                if add_time_ids is None or mask is None:
+                    add_time_ids = text_embedding_info.add_time_ids
+
+            text_embedding.append(text_embedding_info.embeds)
+            if not all_masks_are_none:
+                embedding_ranges.append(
+                    Range(
+                        start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
+                    )
+                )
+                processed_masks.append(self._preprocess_regional_prompt_mask(mask, latent_height, latent_width))
+
+            cur_text_embedding_len += text_embedding_info.embeds.shape[1]
+
+        text_embedding = torch.cat(text_embedding, dim=1)
+        assert len(text_embedding.shape) == 3  # batch_size, seq_len, token_len
+
+        regions = None
+        if not all_masks_are_none:
+            regions = TextConditioningRegions(
+                masks=torch.cat(processed_masks, dim=1),
+                ranges=embedding_ranges,
+            )
+
+        if extra_conditioning is not None and len(text_conditionings) > 1:
+            raise ValueError(
+                "Prompt-to-prompt cross-attention control (a.k.a. `swap()`) is not supported when using multiple "
+                "prompts."
+            )
+
+        if is_sdxl:
+            return SDXLConditioningInfo(
+                embeds=text_embedding,
+                extra_conditioning=extra_conditioning,
+                pooled_embeds=pooled_embedding,
+                add_time_ids=add_time_ids,
+            ), regions
+        return BasicConditioningInfo(
+            embeds=text_embedding,
+            extra_conditioning=extra_conditioning,
+        ), regions
+
     def get_conditioning_data(
         self,
         context: InvocationContext,
         unet: UNet2DConditionModel,
+        latent_height: int,
+        latent_width: int,
     ) -> TextConditioningData:
-        positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name)
-        c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
+        """Build the TextConditioningData for denoising from the positive and negative conditioning inputs.
+
+        Each input may be a single ConditioningField or a list; lists are merged into one embedding and, when any
+        entry has a mask, region masks are resized to (latent_height, latent_width).
+        """
+        # Normalize self.positive_conditioning and self.negative_conditioning to lists.
+        cond_list = self.positive_conditioning
+        if not isinstance(cond_list, list):
+            cond_list = [cond_list]
+        uncond_list = self.negative_conditioning
+        if not isinstance(uncond_list, list):
+            uncond_list = [uncond_list]
 
-        negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name)
-        uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
+        cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
+            cond_list, context, unet.device, unet.dtype
+        )
+        uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
+            uncond_list, context, unet.device, unet.dtype
+        )
+
+        cond_text_embedding, cond_regions = self._concat_regional_text_embeddings(
+            text_conditionings=cond_text_embeddings,
+            masks=cond_text_embedding_masks,
+            latent_height=latent_height,
+            latent_width=latent_width,
+        )
+        uncond_text_embedding, uncond_regions = self._concat_regional_text_embeddings(
+            text_conditionings=uncond_text_embeddings,
+            masks=uncond_text_embedding_masks,
+            latent_height=latent_height,
+            latent_width=latent_width,
+        )
 
         conditioning_data = TextConditioningData(
-            uncond_text=uc,
-            cond_text=c,
-            uncond_regions=None,
-            cond_regions=None,
+            uncond_text=uncond_text_embedding,
+            cond_text=cond_text_embedding,
+            uncond_regions=uncond_regions,
+            cond_regions=cond_regions,
             guidance_scale=self.cfg_scale,
             guidance_rescale_multiplier=self.cfg_rescale_multiplier,
         )
-
         return conditioning_data
 
     def create_pipeline(
@@ -758,7 +941,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 )
 
                 pipeline = self.create_pipeline(unet, scheduler)
-                conditioning_data = self.get_conditioning_data(context, unet)
+
+                _, _, latent_height, latent_width = latents.shape
+                conditioning_data = self.get_conditioning_data(
+                    context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
+                )
 
                 controlnet_data = self.prep_control_data(
                     context=context,