diff --git a/invokeai/app/invocations/conditioning.py b/invokeai/app/invocations/conditioning.py
index 2e46149271..323ae6c038 100644
--- a/invokeai/app/invocations/conditioning.py
+++ b/invokeai/app/invocations/conditioning.py
@@ -33,7 +33,7 @@ class AddConditioningMaskInvocation(BaseInvocation):
     def convert_image_to_mask(image: Image.Image) -> torch.Tensor:
         """Convert a PIL image to a uint8 mask tensor."""
         np_image = np.array(image)
-        torch_image = torch.from_numpy(np_image[0, :, :])
+        torch_image = torch.from_numpy(np_image[:, :, 0])
         mask = torch_image >= 128
         return mask.to(dtype=torch.uint8)
 
diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index ba84005e91..dbebc2ab82 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -340,16 +340,24 @@ class DenoiseLatentsInvocation(BaseInvocation):
             positive_conditioning_list = [positive_conditioning_list]
 
         text_embeddings: list[BasicConditioningInfo] = []
+        text_embeddings_masks: list[Optional[torch.Tensor]] = []
         for positive_conditioning in positive_conditioning_list:
             positive_cond_data = context.services.latents.get(positive_conditioning.conditioning_name)
             text_embeddings.append(positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype))
 
+            mask_name = positive_conditioning.mask_name
+            mask = None
+            if mask_name is not None:
+                mask = context.services.latents.get(mask_name)
+            text_embeddings_masks.append(mask)
+
         negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
         uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
 
         conditioning_data = ConditioningData(
             unconditioned_embeddings=uc,
             text_embeddings=text_embeddings,
+            text_embedding_masks=text_embeddings_masks,
             guidance_scale=self.cfg_scale,
             guidance_rescale_multiplier=self.cfg_rescale_multiplier,
             postprocessing_settings=PostprocessingSettings(
diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
index 2fa66632b4..fb104d7bab 100644
--- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
+++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
@@ -63,6 +63,8 @@ class IPAdapterConditioningInfo:
 class ConditioningData:
     unconditioned_embeddings: BasicConditioningInfo
     text_embeddings: list[BasicConditioningInfo]
+    text_embedding_masks: list[Optional[torch.Tensor]]
+
     """
     Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
     `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
diff --git a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_attention.py b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_attention.py
index 6d4ef77745..963aa3aa99 100644
--- a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_attention.py
+++ b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_attention.py
@@ -1,4 +1,5 @@
 from contextlib import contextmanager
+from dataclasses import dataclass
 from typing import Optional
 
 import torch
@@ -8,6 +9,26 @@ from diffusers.models.attention_processor import Attention, AttnProcessor2_0
 from diffusers.utils import USE_PEFT_BACKEND
 
 
+@dataclass
+class Range:
+    start: int
+    end: int
+
+
+@dataclass
+class RegionalPromptData:
+    # The region masks for each prompt.
+    # shape: (batch_size, num_prompts, height, width)
+    # dtype: float*
+    # The mask is set to 1.0 in regions where the prompt should be applied, and 0.0 elsewhere.
+    masks: torch.Tensor
+
+    # The embedding ranges for each prompt.
+    # The i'th mask is applied to the embeddings in:
+    # encoder_hidden_states[:, embedding_ranges[i].start:embedding_ranges[i].end, :]
+    embedding_ranges: list[Range]
+
+
 class RegionalPromptAttnProcessor2_0(AttnProcessor2_0):
     """An attention processor that supports regional prompt attention for PyTorch 2.0."""
 
@@ -19,10 +40,8 @@ class RegionalPromptAttnProcessor2_0(AttnProcessor2_0):
         attention_mask: Optional[torch.FloatTensor] = None,
         temb: Optional[torch.FloatTensor] = None,
         scale: float = 1.0,
-        regional_prompt_data=None,
+        regional_prompt_data: Optional[RegionalPromptData] = None,
     ) -> torch.FloatTensor:
-        assert regional_prompt_data is None
-
         residual = hidden_states
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
index f202c121ff..7dbe3586e7 100644
--- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
+++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
@@ -2,9 +2,10 @@ from __future__ import annotations
 
 import math
 from contextlib import contextmanager
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Optional, Tuple, Union
 
 import torch
+import torchvision
 from diffusers import UNet2DConditionModel
 from typing_extensions import TypeAlias
 
@@ -16,6 +17,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     PostprocessingSettings,
     SDXLConditioningInfo,
 )
+from invokeai.backend.stable_diffusion.diffusion.regional_prompt_attention import Range, RegionalPromptData
 
 from .cross_attention_control import (
     CrossAttentionType,
@@ -308,26 +310,43 @@ class InvokeAIDiffuserComponent:
 
         return torch.cat([unconditioning, conditioning]), encoder_attention_mask
 
-    # methods below are called from do_diffusion_step and should be considered private to this class.
+    def _preprocess_regional_prompt_mask(
+        self, mask: Optional[torch.Tensor], target_height: int, target_width: int
+    ) -> torch.Tensor:
+        if mask is None:
+            # HACK(ryand): Figure out how to know the target device/dtype.
+            return torch.ones((1, 1, target_height, target_width), dtype=torch.float16, device="cuda")
+        else:
+            # HACK(ryand): It would make more sense to do NEAREST resizing with an integer dtype, and probably on the
+            # CPU.
+            tf = torchvision.transforms.Resize(
+                (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
+            )
+            mask = mask.unsqueeze(0).unsqueeze(0)  # Shape: (h, w) -> (1, 1, h, w)
+            mask = tf(mask)
+
+        return mask
 
     def _prepare_text_embeddings(
-        self, text_embeddings: list[Union[BasicConditioningInfo, SDXLConditioningInfo]]
-    ) -> Union[BasicConditioningInfo, SDXLConditioningInfo]:
-        if len(text_embeddings) == 1:
-            # If there is only one text embedding, we can just return it.
-            # We short-circuit here, because there are some features that are only supported when there is a single
-            # text_embedding provided.
-            return text_embeddings[0]
-
+        self,
+        text_embeddings: list[Union[BasicConditioningInfo, SDXLConditioningInfo]],
+        masks: list[Optional[torch.Tensor]],
+        target_height: int,
+        target_width: int,
+    ) -> Tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[RegionalPromptData]]:
         is_sdxl = type(text_embeddings[0]) is SDXLConditioningInfo
 
+        all_masks_are_none = all(mask is None for mask in masks)
+
         text_embedding = []
         pooled_embedding = None
         add_time_ids = None
+        processed_masks = []
+        cur_text_embedding_len = 0
+        embedding_ranges: list[Range] = []
 
-        for text_embedding_info in text_embeddings:
-            # TODO(ryand): Having to check this feels super hacky.
-            # Extra conditioning is not supported when there are multiple text embeddings.
+        for text_embedding_info, mask in zip(text_embeddings, masks, strict=True):
+            # HACK(ryand): Figure out the intended relationship between CAC and other conditioning features.
             assert (
                 text_embedding_info.extra_conditioning is None
                 or not text_embedding_info.extra_conditioning.wants_cross_attention_control
@@ -343,21 +362,35 @@ class InvokeAIDiffuserComponent:
                 add_time_ids = text_embedding_info.add_time_ids
 
             text_embedding.append(text_embedding_info.embeds)
+            embedding_ranges.append(
+                Range(start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1])
+            )
+            cur_text_embedding_len += text_embedding_info.embeds.shape[1]
+
+            if not all_masks_are_none:
+                processed_masks.append(self._preprocess_regional_prompt_mask(mask, target_height, target_width))
 
         text_embedding = torch.cat(text_embedding, dim=1)
         assert len(text_embedding.shape) == 3  # batch_size, seq_len, token_len
 
+        regional_prompt_data = None
+        if not all_masks_are_none:
+            # TODO(ryand): Think about at what point a batch dimension should be added to the masks.
+            processed_masks = torch.cat(processed_masks, dim=1)
+
+            regional_prompt_data = RegionalPromptData(masks=processed_masks, embedding_ranges=embedding_ranges)
+
         if is_sdxl:
             return SDXLConditioningInfo(
                 embeds=text_embedding,
                 extra_conditioning=None,
                 pooled_embeds=pooled_embedding,
                 add_time_ids=add_time_ids,
-            )
+            ), regional_prompt_data
 
         return BasicConditioningInfo(
             embeds=text_embedding,
             extra_conditioning=None,
-        )
+        ), regional_prompt_data
 
     def _apply_standard_conditioning(
         self,
@@ -374,11 +407,20 @@ class InvokeAIDiffuserComponent:
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)
 
-        text_embeddings = self._prepare_text_embeddings(conditioning_data.text_embeddings)
-        if len(conditioning_data.text_embeddings) > 1:
-            cross_attention_kwargs = {"regional_prompt_data": None}
-
+        # HACK(ryand): We should only have to call _prepare_text_embeddings once, but we currently re-run it on every
+        # denoising step.
         cross_attention_kwargs = None
+        _, _, h, w = x.shape
+        text_embeddings, regional_prompt_data = self._prepare_text_embeddings(
+            text_embeddings=conditioning_data.text_embeddings,
+            masks=conditioning_data.text_embedding_masks,
+            target_height=h,
+            target_width=w,
+        )
+        if regional_prompt_data is not None:
+            cross_attention_kwargs = {"regional_prompt_data": regional_prompt_data}
+
+        # TODO(ryand): Figure out interactions between regional prompting and IP-Adapter conditioning.
         if conditioning_data.ip_adapter_conditioning is not None:
            # Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
             cross_attention_kwargs = {
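
Note (not part of the diff): the hunks above add the plumbing for regional prompts. Masks are gathered in DenoiseLatentsInvocation, carried on ConditioningData.text_embedding_masks, and packaged into RegionalPromptData by _prepare_text_embeddings, but RegionalPromptAttnProcessor2_0 does not yet act on regional_prompt_data (only the assert that blocked it is removed). The sketch below shows one way the masks and embedding_ranges fields could be combined into an additive cross-attention bias; it is a minimal, hypothetical example rather than code from this branch, and build_regional_attn_bias, the locally mirrored dataclasses, and the -10000.0 bias value are all illustrative assumptions.

from dataclasses import dataclass

import torch
import torch.nn.functional as F


@dataclass
class Range:
    start: int
    end: int


@dataclass
class RegionalPromptData:
    masks: torch.Tensor  # (batch_size, num_prompts, height, width); 1.0 where the prompt applies.
    embedding_ranges: list[Range]  # Token range of each prompt within the concatenated text embedding.


def build_regional_attn_bias(data: RegionalPromptData, query_h: int, query_w: int, text_seq_len: int) -> torch.Tensor:
    # Returns an additive bias of shape (batch_size, query_h * query_w, text_seq_len). Query positions that
    # fall outside a prompt's region receive a large negative bias toward that prompt's tokens, so the
    # cross-attention softmax effectively ignores those tokens there.
    batch_size, num_prompts, _, _ = data.masks.shape

    # Downsample each prompt mask to this attention layer's spatial resolution, then flatten it so there is
    # one value per query position.
    resized = F.interpolate(data.masks, size=(query_h, query_w), mode="nearest")
    resized = resized.view(batch_size, num_prompts, query_h * query_w)

    bias = torch.zeros((batch_size, query_h * query_w, text_seq_len), dtype=data.masks.dtype)
    for i, r in enumerate(data.embedding_ranges):
        outside = resized[:, i, :] < 0.5  # (batch_size, query_seq_len)
        # Suppress attention from out-of-region query positions to this prompt's token range.
        bias[:, :, r.start : r.end] += outside.unsqueeze(-1) * -10000.0
    return bias


# Example: two 77-token prompts, each owning half of a 64x64 latent, at a 16x16 attention resolution.
masks = torch.zeros((1, 2, 64, 64))
masks[:, 0, :, :32] = 1.0  # prompt 0 -> left half
masks[:, 1, :, 32:] = 1.0  # prompt 1 -> right half
ranges = [Range(start=0, end=77), Range(start=77, end=154)]
bias = build_regional_attn_bias(RegionalPromptData(masks=masks, embedding_ranges=ranges), 16, 16, 154)
print(bias.shape)  # torch.Size([1, 256, 154])

In the patch itself, something along these lines would presumably live inside RegionalPromptAttnProcessor2_0.__call__, driven by the regional_prompt_data passed through cross_attention_kwargs; that step is not implemented in the hunks shown.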