diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index ba84005e91..2faf74055d 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -41,9 +41,9 @@ from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus from invokeai.backend.model_management.models import ModelType, SilenceWarnings from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( - BasicConditioningInfo, ConditioningData, IPAdapterConditioningInfo, + TextConditioningInfoWithMask, ) from ...backend.model_management.lora import ModelPatcher @@ -339,10 +339,20 @@ class DenoiseLatentsInvocation(BaseInvocation): if not isinstance(positive_conditioning_list, list): positive_conditioning_list = [positive_conditioning_list] - text_embeddings: list[BasicConditioningInfo] = [] + text_embeddings: list[TextConditioningInfoWithMask] = [] for positive_conditioning in positive_conditioning_list: positive_cond_data = context.services.latents.get(positive_conditioning.conditioning_name) - text_embeddings.append(positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)) + mask_name = positive_conditioning.mask_name + mask = None + if mask_name is not None: + mask = context.services.latents.get(mask_name) + text_embeddings.append( + TextConditioningInfoWithMask( + text_conditioning_info=positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype), + mask=mask, + mask_strength=positive_conditioning.mask_strength, + ) + ) negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name) uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index d34016d128..71909523e4 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -403,16 +403,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): if timesteps.shape[0] == 0: return latents - extra_conditioning_info = conditioning_data.text_embeddings[0].extra_conditioning + extra_conditioning_info = conditioning_data.text_embeddings[0].text_conditioning_info.extra_conditioning use_cross_attention_control = ( extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control ) use_ip_adapter = ip_adapter_data is not None - use_regional_prompting = len(conditioning_data.text_embeddings) > 1 - if sum([use_cross_attention_control, use_ip_adapter, use_regional_prompting]) > 1: - raise Exception( - "Cross-attention control, IP-Adapter, and regional prompting cannot be used simultaneously (yet)." - ) + if sum([use_cross_attention_control, use_ip_adapter]) > 1: + raise Exception("Cross-attention control and IP-Adapter cannot be used simultaneously (yet).") ip_adapter_unet_patcher = None if use_cross_attention_control: @@ -427,8 +424,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data]) attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model) self.use_ip_adapter = True - elif use_regional_prompting: - raise NotImplementedError("Regional prompting is not yet supported.") else: attn_ctx = nullcontext() diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index 2fa66632b4..485f23d7b1 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -39,6 +39,18 @@ class SDXLConditioningInfo(BasicConditioningInfo): return super().to(device=device, dtype=dtype) +class TextConditioningInfoWithMask: + def __init__( + self, + text_conditioning_info: Union[BasicConditioningInfo, SDXLConditioningInfo], + mask: Optional[torch.Tensor], + mask_strength: float, + ): + self.text_conditioning_info = text_conditioning_info + self.mask = mask + self.mask_strength = mask_strength + + @dataclass(frozen=True) class PostprocessingSettings: threshold: float @@ -62,7 +74,7 @@ class IPAdapterConditioningInfo: @dataclass class ConditioningData: unconditioned_embeddings: BasicConditioningInfo - text_embeddings: list[BasicConditioningInfo] + text_embeddings: list[TextConditioningInfoWithMask] """ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index a2ec7fc891..10c76c43a4 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -94,9 +94,9 @@ class InvokeAIDiffuserComponent: conditioning_data: ConditioningData, ): down_block_res_samples, mid_block_res_sample = None, None - # HACK(ryan): Currently, we just take the first text embedding if there's more than one. We should probably - # concatenate all of the embeddings for the ControlNet, but not apply embedding masks. - text_embeddings = conditioning_data.text_embeddings[0] + # HACK(ryan): Currently, we just take the first text embedding if there's more than one. We should probably run + # the controlnet separately for each conditioning input. + text_embeddings = conditioning_data.text_embeddings[0].text_conditioning_info # control_data should be type List[ControlNetData] # this loop covers both ControlNet (one ControlNetData in list) @@ -325,7 +325,7 @@ class InvokeAIDiffuserComponent: sigma_twice = torch.cat([sigma] * 2) assert len(conditioning_data.text_embeddings) == 1 - text_embeddings = conditioning_data.text_embeddings[0] + text_embeddings = conditioning_data.text_embeddings[0].text_conditioning_info cross_attention_kwargs = None if conditioning_data.ip_adapter_conditioning is not None: @@ -391,7 +391,7 @@ class InvokeAIDiffuserComponent: """ assert len(conditioning_data.text_embeddings) == 1 - text_embeddings = conditioning_data.text_embeddings[0] + text_embeddings = conditioning_data.text_embeddings[0].text_conditioning_info # Since we are running the conditioned and unconditioned passes sequentially, we need to split the ControlNet # and T2I-Adapter residuals into two chunks.