diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index 2c765c0380..c33f7b7370 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -404,22 +404,35 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if timesteps.shape[0] == 0:
             return latents
 
-        ip_adapter_unet_patcher = None
         extra_conditioning_info = conditioning_data.cond_text.extra_conditioning
-        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
+        use_cross_attention_control = (
+            extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
+        )
+        use_ip_adapter = ip_adapter_data is not None
+        use_regional_prompting = (
+            conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
+        )
+        if use_cross_attention_control and use_ip_adapter:
+            raise ValueError(
+                "Prompt-to-prompt cross-attention control (`.swap()`) and IP-Adapter cannot be used simultaneously."
+            )
+        if use_cross_attention_control and use_regional_prompting:
+            raise ValueError(
+                "Prompt-to-prompt cross-attention control (`.swap()`) and regional prompting cannot be used simultaneously."
+            )
+
+        unet_attention_patcher = None
+        self.use_ip_adapter = use_ip_adapter
+        attn_ctx = nullcontext()
+        if use_cross_attention_control:
             attn_ctx = self.invokeai_diffuser.custom_attention_context(
                 self.invokeai_diffuser.model,
                 extra_conditioning_info=extra_conditioning_info,
             )
-            self.use_ip_adapter = False
-        elif ip_adapter_data is not None:
-            # TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
-            # As it is now, the IP-Adapter will silently be skipped.
-            ip_adapter_unet_patcher = UNetAttentionPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
-            attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
-            self.use_ip_adapter = True
-        else:
-            attn_ctx = nullcontext()
+        if use_ip_adapter or use_regional_prompting:
+            ip_adapters = [ipa.ip_adapter_model for ipa in ip_adapter_data] if use_ip_adapter else None
+            unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
+            attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
 
         with attn_ctx:
             if callback is not None:
@@ -447,7 +460,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                     control_data=control_data,
                     ip_adapter_data=ip_adapter_data,
                     t2i_adapter_data=t2i_adapter_data,
-                    ip_adapter_unet_patcher=ip_adapter_unet_patcher,
+                    unet_attention_patcher=unet_attention_patcher,
                 )
                 latents = step_output.prev_sample
                 predicted_original = getattr(step_output, "pred_original_sample", None)
@@ -479,7 +492,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         control_data: List[ControlNetData] = None,
         ip_adapter_data: Optional[list[IPAdapterData]] = None,
         t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
-        ip_adapter_unet_patcher: Optional[UNetAttentionPatcher] = None,
+        unet_attention_patcher: Optional[UNetAttentionPatcher] = None,
     ):
         # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
         timestep = t[0]
@@ -506,10 +519,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 )
                 if step_index >= first_adapter_step and step_index <= last_adapter_step:
                     # Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
-                    ip_adapter_unet_patcher.set_scale(i, weight)
+                    unet_attention_patcher.set_scale(i, weight)
                 else:
                     # Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
-                    ip_adapter_unet_patcher.set_scale(i, 0.0)
+                    unet_attention_patcher.set_scale(i, 0.0)
 
         # Handle ControlNet(s)
         down_block_additional_residuals = None
diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
index 47f81ff7aa..2f7523dd46 100644
--- a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
+++ b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
@@ -88,13 +88,12 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         # End unmodified block from AttnProcessor2_0.
 
         # Handle regional prompt attention masks.
-        if regional_prompt_data is not None:
+        if regional_prompt_data is not None and is_cross_attention:
             assert percent_through is not None
             _, query_seq_len, _ = hidden_states.shape
-            if is_cross_attention:
-                prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
-                    query_seq_len=query_seq_len, key_seq_len=sequence_length
-                )
+            prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
+                query_seq_len=query_seq_len, key_seq_len=sequence_length
+            )
 
             if attention_mask is None:
                 attention_mask = prompt_region_attention_mask
diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
index 46150d2621..8ba988a0eb 100644
--- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
+++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
@@ -12,8 +12,11 @@ from invokeai.app.services.config.config_default import get_config
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     ExtraConditioningInfo,
     IPAdapterConditioningInfo,
+    Range,
     TextConditioningData,
+    TextConditioningRegions,
 )
+from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
 
 from .cross_attention_control import (
     CrossAttentionType,
@@ -206,9 +209,9 @@ class InvokeAIDiffuserComponent:
         mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
         down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
     ):
+        percent_through = step_index / total_step_count
         cross_attention_control_types_to_do = []
         if self.cross_attention_control_context is not None:
-            percent_through = step_index / total_step_count
             cross_attention_control_types_to_do = (
                 self.cross_attention_control_context.get_active_cross_attention_control_types_for_step(percent_through)
             )
@@ -225,6 +228,7 @@ class InvokeAIDiffuserComponent:
                 sigma=timestep,
                 conditioning_data=conditioning_data,
                 ip_adapter_conditioning=ip_adapter_conditioning,
+                percent_through=percent_through,
                 cross_attention_control_types_to_do=cross_attention_control_types_to_do,
                 down_block_additional_residuals=down_block_additional_residuals,
                 mid_block_additional_residual=mid_block_additional_residual,
@@ -239,6 +243,7 @@ class InvokeAIDiffuserComponent:
                 sigma=timestep,
                 conditioning_data=conditioning_data,
                 ip_adapter_conditioning=ip_adapter_conditioning,
+                percent_through=percent_through,
                 down_block_additional_residuals=down_block_additional_residuals,
                 mid_block_additional_residual=mid_block_additional_residual,
                 down_intrablock_additional_residuals=down_intrablock_additional_residuals,
@@ -301,6 +306,7 @@ class InvokeAIDiffuserComponent:
         sigma,
         conditioning_data: TextConditioningData,
         ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
+        percent_through: float,
         down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
         mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
         down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
@@ -311,17 +317,13 @@ class InvokeAIDiffuserComponent:
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)
 
-        cross_attention_kwargs = None
+        cross_attention_kwargs = {}
         if ip_adapter_conditioning is not None:
             # Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
-            cross_attention_kwargs = {
-                "ip_adapter_image_prompt_embeds": [
-                    torch.stack(
-                        [ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds]
-                    )
-                    for ipa_conditioning in ip_adapter_conditioning
-                ]
-            }
+            cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
+                torch.stack([ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds])
+                for ipa_conditioning in ip_adapter_conditioning
+            ]
 
         added_cond_kwargs = None
         if conditioning_data.is_sdxl():
@@ -343,6 +345,31 @@ class InvokeAIDiffuserComponent:
                 ),
             }
 
+        if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
+            # TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
+            # and masks are not changing from step-to-step, so this really only needs to be done once. While this seems
+            # painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
+            # the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
+            # awkward to handle both standard conditioning and sequential conditioning further up the stack.
+            regions = []
+            for c, r in [
+                (conditioning_data.uncond_text, conditioning_data.uncond_regions),
+                (conditioning_data.cond_text, conditioning_data.cond_regions),
+            ]:
+                if r is None:
+                    # Create a dummy mask and range for text conditioning that doesn't have region masks.
+                    _, _, h, w = x.shape
+                    r = TextConditioningRegions(
+                        masks=torch.ones((1, 1, h, w), dtype=torch.bool),
+                        ranges=[Range(start=0, end=c.embeds.shape[1])],
+                    )
+                regions.append(r)
+
+            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
+                regions=regions, device=x.device, dtype=x.dtype
+            )
+            cross_attention_kwargs["percent_through"] = percent_through
+
         both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
             conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds
         )
@@ -366,6 +393,7 @@ class InvokeAIDiffuserComponent:
         sigma,
         conditioning_data: TextConditioningData,
         ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
+        percent_through: float,
         cross_attention_control_types_to_do: list[CrossAttentionType],
         down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
         mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
@@ -413,21 +441,19 @@ class InvokeAIDiffuserComponent:
         # Unconditioned pass
         #####################
 
-        cross_attention_kwargs = None
+        cross_attention_kwargs = {}
 
         # Prepare IP-Adapter cross-attention kwargs for the unconditioned pass.
         if ip_adapter_conditioning is not None:
             # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
-            cross_attention_kwargs = {
-                "ip_adapter_image_prompt_embeds": [
-                    torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
-                    for ipa_conditioning in ip_adapter_conditioning
-                ]
-            }
+            cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
+                torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
+                for ipa_conditioning in ip_adapter_conditioning
+            ]
 
         # Prepare cross-attention control kwargs for the unconditioned pass.
         if cross_attn_processor_context is not None:
-            cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
+            cross_attention_kwargs["swap_cross_attn_context"] = cross_attn_processor_context
 
         # Prepare SDXL conditioning kwargs for the unconditioned pass.
         added_cond_kwargs = None
@@ -437,6 +463,13 @@ class InvokeAIDiffuserComponent:
                 "time_ids": conditioning_data.uncond_text.add_time_ids,
             }
 
+        # Prepare prompt regions for the unconditioned pass.
+        if conditioning_data.uncond_regions is not None:
+            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
+                regions=[conditioning_data.uncond_regions], device=x.device, dtype=x.dtype
+            )
+            cross_attention_kwargs["percent_through"] = percent_through
+
         # Run unconditioned UNet denoising (i.e. negative prompt).
         unconditioned_next_x = self.model_forward_callback(
             x,
@@ -453,22 +486,20 @@ class InvokeAIDiffuserComponent:
         # Conditioned pass
         ###################
 
-        cross_attention_kwargs = None
+        cross_attention_kwargs = {}
 
         # Prepare IP-Adapter cross-attention kwargs for the conditioned pass.
        if ip_adapter_conditioning is not None:
             # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
-            cross_attention_kwargs = {
-                "ip_adapter_image_prompt_embeds": [
-                    torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
-                    for ipa_conditioning in ip_adapter_conditioning
-                ]
-            }
+            cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
+                torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
+                for ipa_conditioning in ip_adapter_conditioning
+            ]
 
         # Prepare cross-attention control kwargs for the conditioned pass.
         if cross_attn_processor_context is not None:
             cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
-            cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
+            cross_attention_kwargs["swap_cross_attn_context"] = cross_attn_processor_context
 
         # Prepare SDXL conditioning kwargs for the conditioned pass.
         added_cond_kwargs = None
@@ -478,6 +509,13 @@ class InvokeAIDiffuserComponent:
                 "time_ids": conditioning_data.cond_text.add_time_ids,
             }
 
+        # Prepare prompt regions for the conditioned pass.
+        if conditioning_data.cond_regions is not None:
+            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
+                regions=[conditioning_data.cond_regions], device=x.device, dtype=x.dtype
+            )
+            cross_attention_kwargs["percent_through"] = percent_through
+
         # Run conditioned UNet denoising (i.e. positive prompt).
         conditioned_next_x = self.model_forward_callback(
             x,
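
Note (not part of the diff above): the hunks route RegionalPromptData and percent_through into cross_attention_kwargs so that CustomAttnProcessor2_0 can restrict cross-attention per region. As a rough mental model of what RegionalPromptData.get_cross_attn_mask() produces, here is a minimal, hypothetical Python sketch. The helper name build_cross_attn_mask, its signature, and the downscale inference are illustrative assumptions, not the actual InvokeAI API:

import torch
import torch.nn.functional as F

# Hypothetical sketch of regional cross-attention masking (illustration only).
def build_cross_attn_mask(
    masks: torch.Tensor,  # (num_regions, 1, h, w) boolean region masks at latent resolution
    ranges: list[tuple[int, int]],  # per-region (start, end) token ranges into the text embedding
    query_seq_len: int,  # number of spatial positions at the current attention layer
    key_seq_len: int,  # text embedding sequence length
) -> torch.Tensor:
    # Each latent position may only attend to the text tokens of the regions covering it.
    num_regions, _, h, w = masks.shape
    # Attention layers run at downsampled resolutions; infer the factor from
    # query_seq_len (assumes h * w is a square-integer multiple of it).
    downscale = int((h * w // query_seq_len) ** 0.5)
    spatial = F.interpolate(masks.float(), size=(h // downscale, w // downscale), mode="nearest")
    spatial = spatial.flatten(2)  # (num_regions, 1, query_seq_len)

    attn_mask = torch.zeros(query_seq_len, key_seq_len, dtype=torch.bool)
    for region_idx, (start, end) in enumerate(ranges):
        covered = spatial[region_idx, 0] > 0.5  # latent positions inside this region
        attn_mask[covered, start:end] = True
    return attn_mask

# Example: two regions splitting a 64x64 latent left/right, with text tokens
# 0-10 bound to the left half and tokens 10-20 bound to the right half.
masks = torch.zeros(2, 1, 64, 64, dtype=torch.bool)
masks[0, :, :, :32] = True
masks[1, :, :, 32:] = True
mask = build_cross_attn_mask(masks, [(0, 10), (10, 20)], query_seq_len=32 * 32, key_seq_len=77)

Per the custom_atttention.py hunk, the real processor only applies such a mask on cross-attention layers (note the added "and is_cross_attention" guard) and installs it as the attention_mask when none is already set.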