From e132afb705f37703be4f91111d5bcc282080dca3 Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Wed, 28 Feb 2024 21:21:50 -0500
Subject: [PATCH] Make regional prompting work with sequential conditioning.

---
 .../diffusion/regional_prompt_attention.py   | 12 +++++++-----
 .../diffusion/shared_invokeai_diffusion.py   | 18 ++++++++++++++++++
 2 files changed, 25 insertions(+), 5 deletions(-)

diff --git a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_attention.py b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_attention.py
index c433199f6d..74d7c2755d 100644
--- a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_attention.py
+++ b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_attention.py
@@ -126,19 +126,21 @@ class RegionalPromptAttnProcessor2_0(AttnProcessor2_0):
             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
         )
 
-        if encoder_hidden_states is not None:
-            assert regional_prompt_data is not None
-            assert attention_mask is not None
+        if encoder_hidden_states is not None and regional_prompt_data is not None:
+            # If encoder_hidden_states is not None, then we are in the cross-attention case.
             _, query_seq_len, _ = hidden_states.shape
             prompt_region_attention_mask = regional_prompt_data.get_attn_mask(query_seq_len)
             # TODO(ryand): Avoid redundant type/device conversion here.
             prompt_region_attention_mask = prompt_region_attention_mask.to(
-                dtype=attention_mask.dtype, device=attention_mask.device
+                dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device
             )
             prompt_region_attention_mask[prompt_region_attention_mask < 0.5] = -10000.0
             prompt_region_attention_mask[prompt_region_attention_mask >= 0.5] = 0.0
 
-            attention_mask = prompt_region_attention_mask + attention_mask
+            if attention_mask is None:
+                attention_mask = prompt_region_attention_mask
+            else:
+                attention_mask = prompt_region_attention_mask + attention_mask
 
         if attention_mask is not None:
             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
index c84de2b605..d66f7d83a1 100644
--- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
+++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
@@ -448,6 +448,15 @@ class InvokeAIDiffuserComponent:
                 "time_ids": conditioning_data.uncond_text.add_time_ids,
             }
 
+        # Prepare prompt regions for the unconditioned pass.
+        if conditioning_data.uncond_regions is not None:
+            _, key_seq_len, _ = conditioning_data.uncond_text.embeds.shape
+            cross_attention_kwargs = {
+                "regional_prompt_data": RegionalPromptData.from_regions(
+                    regions=[conditioning_data.uncond_regions], key_seq_len=key_seq_len
+                )
+            }
+
         # Run unconditioned UNet denoising (i.e. negative prompt).
         unconditioned_next_x = self.model_forward_callback(
             x,
@@ -489,6 +498,15 @@ class InvokeAIDiffuserComponent:
                 "time_ids": conditioning_data.cond_text.add_time_ids,
             }
 
+        # Prepare prompt regions for the conditioned pass.
+        if conditioning_data.cond_regions is not None:
+            _, key_seq_len, _ = conditioning_data.cond_text.embeds.shape
+            cross_attention_kwargs = {
+                "regional_prompt_data": RegionalPromptData.from_regions(
+                    regions=[conditioning_data.cond_regions], key_seq_len=key_seq_len
+                )
+            }
+
         # Run conditioned UNet denoising (i.e. positive prompt).
         conditioned_next_x = self.model_forward_callback(
             x,
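
Note: the sketch below illustrates the mask-merging behavior this patch adds to RegionalPromptAttnProcessor2_0, outside of the InvokeAI code: a soft regional prompt mask is thresholded into an additive attention bias and only summed with a pre-existing attention mask when one is present, which is what lets the sequential-conditioning path (where attention_mask may be None) work. This is a minimal, self-contained illustration, not the patched implementation; the function name merge_regional_attention_mask and the example tensors are hypothetical.

    from typing import Optional

    import torch


    def merge_regional_attention_mask(
        soft_region_mask: torch.Tensor,           # (batch, query_seq_len, key_seq_len), values in [0, 1]
        attention_mask: Optional[torch.Tensor],   # optional additive mask of the same shape
        dtype: torch.dtype,
        device: torch.device,
    ) -> torch.Tensor:
        # Convert the soft region mask into an additive attention bias:
        # a large negative value where the region does not apply, zero where it does.
        bias = soft_region_mask.to(dtype=dtype, device=device).clone()
        bias[bias < 0.5] = -10000.0
        bias[bias >= 0.5] = 0.0

        # Sequential conditioning can reach the attention processor without an
        # existing attention mask; in that case the regional bias is used on its own.
        if attention_mask is None:
            return bias
        return bias + attention_mask


    # Example usage with made-up shapes: 64 query (latent) tokens, 77 prompt tokens.
    soft_mask = (torch.rand(1, 64, 77) > 0.5).float()
    merged = merge_regional_attention_mask(soft_mask, None, torch.float32, torch.device("cpu"))
    print(merged.shape)  # torch.Size([1, 64, 77])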