From 2e27ed5f3db22805e8d829947eb8acb722f4e2d2 Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Thu, 14 Mar 2024 13:56:03 -0400
Subject: [PATCH] Pass IP-Adapter scales through the cross_attn_kwargs pathway,
 since they are the same for all attention layers. This change also helps to
 prepare for adding IP-Adapter region masks.

---
 invokeai/app/invocations/latent.py            |  2 +-
 .../stable_diffusion/diffusers_pipeline.py    | 39 +----------
 .../diffusion/conditioning_data.py            | 24 +++++++
 .../diffusion/custom_atttention.py            | 15 ++---
 .../diffusion/regional_ip_data.py             | 20 ++++++
 .../diffusion/shared_invokeai_diffusion.py    | 66 ++++++++++++-------
 .../diffusion/unet_attention_patcher.py       |  3 -
 7 files changed, 95 insertions(+), 74 deletions(-)
 create mode 100644 invokeai/backend/stable_diffusion/diffusion/regional_ip_data.py

diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index 345621b38f..ba668440b8 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -56,6 +56,7 @@ from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_sea
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
     IPAdapterConditioningInfo,
+    IPAdapterData,
     Range,
     SDXLConditioningInfo,
     TextConditioningData,
@@ -66,7 +67,6 @@ from invokeai.backend.util.silence_warnings import SilenceWarnings

 from ...backend.stable_diffusion.diffusers_pipeline import (
     ControlNetData,
-    IPAdapterData,
     StableDiffusionGeneratorPipeline,
     T2IAdapterData,
     image_resized_to_grid_as_tensor,
diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index 25acd32fea..aff328553d 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -21,9 +21,8 @@ from pydantic import Field
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

 from invokeai.app.services.config.config_default import get_config
-from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
-    IPAdapterConditioningInfo,
+    IPAdapterData,
     TextConditioningData,
 )
 from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
@@ -152,17 +151,6 @@ class ControlNetData:
     resize_mode: str = Field(default="just_resize")


-@dataclass
-class IPAdapterData:
-    ip_adapter_model: IPAdapter
-    ip_adapter_conditioning: IPAdapterConditioningInfo
-
-    # Either a single weight applied to all steps, or a list of weights for each step.
-    weight: Union[float, List[float]] = Field(default=1.0)
-    begin_step_percent: float = Field(default=0.0)
-    end_step_percent: float = Field(default=1.0)
-
-
 @dataclass
 class T2IAdapterData:
     """A structure containing the information required to apply conditioning from a single T2I-Adapter model."""
@@ -439,7 +427,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 control_data=control_data,
                 ip_adapter_data=ip_adapter_data,
                 t2i_adapter_data=t2i_adapter_data,
-                unet_attention_patcher=unet_attention_patcher,
             )
             latents = step_output.prev_sample
             predicted_original = getattr(step_output, "pred_original_sample", None)
@@ -471,7 +458,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         control_data: List[ControlNetData] = None,
         ip_adapter_data: Optional[list[IPAdapterData]] = None,
         t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
-        unet_attention_patcher: Optional[UNetAttentionPatcher] = None,
     ):
         # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
         timestep = t[0]
@@ -486,23 +472,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         # i.e. before or after passing it to InvokeAIDiffuserComponent
         latent_model_input = self.scheduler.scale_model_input(latents, timestep)

-        # handle IP-Adapter
-        if self.use_ip_adapter and ip_adapter_data is not None:  # somewhat redundant but logic is clearer
-            for i, single_ip_adapter_data in enumerate(ip_adapter_data):
-                first_adapter_step = math.floor(single_ip_adapter_data.begin_step_percent * total_step_count)
-                last_adapter_step = math.ceil(single_ip_adapter_data.end_step_percent * total_step_count)
-                weight = (
-                    single_ip_adapter_data.weight[step_index]
-                    if isinstance(single_ip_adapter_data.weight, List)
-                    else single_ip_adapter_data.weight
-                )
-                if step_index >= first_adapter_step and step_index <= last_adapter_step:
-                    # Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
-                    unet_attention_patcher.set_scale(i, weight)
-                else:
-                    # Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
-                    unet_attention_patcher.set_scale(i, 0.0)
-
         # Handle ControlNet(s)
         down_block_additional_residuals = None
         mid_block_additional_residual = None
@@ -545,17 +514,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

             down_intrablock_additional_residuals = accum_adapter_state

-        ip_adapter_conditioning = None
-        if ip_adapter_data is not None:
-            ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]
-
         uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
             sample=latent_model_input,
             timestep=t,  # TODO: debug how handled batched and non batched timesteps
             step_index=step_index,
             total_step_count=total_step_count,
             conditioning_data=conditioning_data,
-            ip_adapter_conditioning=ip_adapter_conditioning,
+            ip_adapter_data=ip_adapter_data,
             down_block_additional_residuals=down_block_additional_residuals,  # for ControlNet
             mid_block_additional_residual=mid_block_additional_residual,  # for ControlNet
             down_intrablock_additional_residuals=down_intrablock_additional_residuals,  # for T2I-Adapter
diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
index 102de19428..7196802ed3 100644
--- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
+++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
@@ -1,8 +1,11 @@
+import math
 from dataclasses import dataclass
 from typing import List, Optional, Union

 import torch

+from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
+

 @dataclass
 class BasicConditioningInfo:
@@ -45,6 +48,27 @@ class IPAdapterConditioningInfo:
     """


+@dataclass
+class IPAdapterData:
+    ip_adapter_model: IPAdapter
+    ip_adapter_conditioning: IPAdapterConditioningInfo
+
+    # Either a single weight applied to all steps, or a list of weights for each step.
+    weight: Union[float, List[float]] = 1.0
+    begin_step_percent: float = 0.0
+    end_step_percent: float = 1.0
+
+    def scale_for_step(self, step_index: int, total_steps: int) -> float:
+        first_adapter_step = math.floor(self.begin_step_percent * total_steps)
+        last_adapter_step = math.ceil(self.end_step_percent * total_steps)
+        weight = self.weight[step_index] if isinstance(self.weight, List) else self.weight
+        if step_index >= first_adapter_step and step_index <= last_adapter_step:
+            # Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
+            return weight
+        # Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
+        return 0.0
+
+
 @dataclass
 class Range:
     start: int
diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
index 667fcd9a64..34f868306b 100644
--- a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
+++ b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
@@ -5,6 +5,7 @@ import torch.nn.functional as F
 from diffusers.models.attention_processor import Attention, AttnProcessor2_0

 from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
+from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData
 from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData


@@ -54,15 +55,13 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         regional_prompt_data: Optional[RegionalPromptData] = None,
         percent_through: Optional[torch.FloatTensor] = None,
         # For IP-Adapter:
-        ip_adapter_image_prompt_embeds: Optional[list[torch.Tensor]] = None,
+        regional_ip_data: Optional[RegionalIPData] = None,
     ) -> torch.FloatTensor:
         """Apply attention.

         Args:
             regional_prompt_data: The regional prompt data for the current batch. If not None, this will be used to
                 apply regional prompt masking.
-            ip_adapter_image_prompt_embeds: The IP-Adapter image prompt embeddings for the current batch.
-                ip_adapter_image_prompt_embeds[i] contains the image prompt embeddings for the i'th IP-Adapter. Each
-                tensor has shape (batch_size, num_ip_images, seq_len, ip_embedding_len).
+            regional_ip_data: The IP-Adapter data for the current batch.
         """
         # If true, we are doing cross-attention, if false we are doing self-attention.
         is_cross_attention = encoder_hidden_states is not None
@@ -141,9 +140,9 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         # Apply IP-Adapter conditioning.
         if is_cross_attention and self._is_ip_adapter_enabled():
             if self._is_ip_adapter_enabled():
-                assert ip_adapter_image_prompt_embeds is not None
+                assert regional_ip_data is not None
                 for ipa_embed, ipa_weights, scale in zip(
-                    ip_adapter_image_prompt_embeds, self._ip_adapter_weights, self._ip_adapter_scales, strict=True
+                    regional_ip_data.image_prompt_embeds, self._ip_adapter_weights, regional_ip_data.scales, strict=True
                 ):
                     # The batch dimensions should match.
                     assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
@@ -178,8 +177,8 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):

                     hidden_states = hidden_states + scale * ip_hidden_states
         else:
-            # If IP-Adapter is not enabled, then ip_adapter_image_prompt_embeds should not be passed in.
-            assert ip_adapter_image_prompt_embeds is None
+            # If IP-Adapter is not enabled, then regional_ip_data should not be passed in.
+            assert regional_ip_data is None

         # Start unmodified block from AttnProcessor2_0.
         # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
diff --git a/invokeai/backend/stable_diffusion/diffusion/regional_ip_data.py b/invokeai/backend/stable_diffusion/diffusion/regional_ip_data.py
new file mode 100644
index 0000000000..ecf878b416
--- /dev/null
+++ b/invokeai/backend/stable_diffusion/diffusion/regional_ip_data.py
@@ -0,0 +1,20 @@
+import torch
+
+
+class RegionalIPData:
+    """A class to manage the data for regional IP-Adapter conditioning."""
+
+    def __init__(
+        self,
+        image_prompt_embeds: list[torch.Tensor],
+        scales: list[float],
+    ):
+        """Initialize a `RegionalIPData` object."""
+        # The image prompt embeddings.
+        # image_prompt_embeds[i] contains the image prompt embeddings for the i'th IP-Adapter. Each tensor
+        # has shape (batch_size, num_ip_images, seq_len, ip_embedding_len).
+        self.image_prompt_embeds = image_prompt_embeds
+
+        # The scales for the IP-Adapter attention.
+        # scales[i] contains the attention scale for the i'th IP-Adapter.
+        self.scales = scales
diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
index f565f58352..05b4a6406d 100644
--- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
+++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
@@ -8,11 +8,12 @@ from typing_extensions import TypeAlias

 from invokeai.app.services.config.config_default import get_config
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
-    IPAdapterConditioningInfo,
+    IPAdapterData,
     Range,
     TextConditioningData,
     TextConditioningRegions,
 )
+from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData
 from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData

 ModelForwardCallback: TypeAlias = Union[
@@ -169,15 +170,13 @@ class InvokeAIDiffuserComponent:
         sample: torch.Tensor,
         timestep: torch.Tensor,
         conditioning_data: TextConditioningData,
-        ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
+        ip_adapter_data: Optional[list[IPAdapterData]],
         step_index: int,
         total_step_count: int,
         down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
         mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
         down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
     ):
-        percent_through = step_index / total_step_count
-
         if self.sequential_guidance:
             (
                 unconditioned_next_x,
                 conditioned_next_x,
             ) = self._apply_standard_conditioning_sequentially(
                 x=sample,
                 sigma=timestep,
                 conditioning_data=conditioning_data,
-                ip_adapter_conditioning=ip_adapter_conditioning,
-                percent_through=percent_through,
+                ip_adapter_data=ip_adapter_data,
+                step_index=step_index,
+                total_step_count=total_step_count,
                 down_block_additional_residuals=down_block_additional_residuals,
                 mid_block_additional_residual=mid_block_additional_residual,
                 down_intrablock_additional_residuals=down_intrablock_additional_residuals,
@@ -200,8 +200,9 @@ class InvokeAIDiffuserComponent:
                 x=sample,
                 sigma=timestep,
                 conditioning_data=conditioning_data,
-                ip_adapter_conditioning=ip_adapter_conditioning,
-                percent_through=percent_through,
+                ip_adapter_data=ip_adapter_data,
+                step_index=step_index,
+                total_step_count=total_step_count,
                 down_block_additional_residuals=down_block_additional_residuals,
                 mid_block_additional_residual=mid_block_additional_residual,
                 down_intrablock_additional_residuals=down_intrablock_additional_residuals,
@@ -260,15 +261,16 @@ class InvokeAIDiffuserComponent:

     def _apply_standard_conditioning(
         self,
-        x,
-        sigma,
+        x: torch.Tensor,
+        sigma: torch.Tensor,
         conditioning_data: TextConditioningData,
-        ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
-        percent_through: float,
+        ip_adapter_data: Optional[list[IPAdapterData]],
+        step_index: int,
+        total_step_count: int,
         down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
         mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
         down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
-    ):
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
         the cost of higher memory usage.
         """
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)

         cross_attention_kwargs = {}
-        if ip_adapter_conditioning is not None:
+        if ip_adapter_data is not None:
+            ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]
             # Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
-            cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
+            image_prompt_embeds = [
                 torch.stack([ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds])
                 for ipa_conditioning in ip_adapter_conditioning
             ]
+            scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
+            regional_ip_data = RegionalIPData(image_prompt_embeds=image_prompt_embeds, scales=scales)
+            cross_attention_kwargs["regional_ip_data"] = regional_ip_data

         added_cond_kwargs = None
         if conditioning_data.is_sdxl():
@@ -326,7 +332,7 @@ class InvokeAIDiffuserComponent:
             cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
                 regions=regions, device=x.device, dtype=x.dtype
             )
-            cross_attention_kwargs["percent_through"] = percent_through
+            cross_attention_kwargs["percent_through"] = step_index / total_step_count

         both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
             conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds
@@ -350,8 +356,9 @@ class InvokeAIDiffuserComponent:
         x: torch.Tensor,
         sigma,
         conditioning_data: TextConditioningData,
-        ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
-        percent_through: float,
+        ip_adapter_data: Optional[list[IPAdapterData]],
+        step_index: int,
+        total_step_count: int,
         down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
         mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
         down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
@@ -388,13 +395,18 @@ class InvokeAIDiffuserComponent:
         cross_attention_kwargs = {}

         # Prepare IP-Adapter cross-attention kwargs for the unconditioned pass.
-        if ip_adapter_conditioning is not None:
+        if ip_adapter_data is not None:
+            ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]
             # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
-            cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
+            image_prompt_embeds = [
                 torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
                 for ipa_conditioning in ip_adapter_conditioning
             ]
+            scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
+            regional_ip_data = RegionalIPData(image_prompt_embeds=image_prompt_embeds, scales=scales)
+            cross_attention_kwargs["regional_ip_data"] = regional_ip_data
+
         # Prepare SDXL conditioning kwargs for the unconditioned pass.
         added_cond_kwargs = None
         if conditioning_data.is_sdxl():
@@ -408,7 +420,7 @@ class InvokeAIDiffuserComponent:
             cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
                 regions=[conditioning_data.uncond_regions], device=x.device, dtype=x.dtype
             )
-            cross_attention_kwargs["percent_through"] = percent_through
+            cross_attention_kwargs["percent_through"] = step_index / total_step_count

         # Run unconditioned UNet denoising (i.e. negative prompt).
         unconditioned_next_x = self.model_forward_callback(
@@ -428,14 +440,18 @@ class InvokeAIDiffuserComponent:

         cross_attention_kwargs = {}

-        # Prepare IP-Adapter cross-attention kwargs for the conditioned pass.
-        if ip_adapter_conditioning is not None:
+        if ip_adapter_data is not None:
+            ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]
             # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
-            cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
+            image_prompt_embeds = [
                 torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
                 for ipa_conditioning in ip_adapter_conditioning
             ]
+            scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
+            regional_ip_data = RegionalIPData(image_prompt_embeds=image_prompt_embeds, scales=scales)
+            cross_attention_kwargs["regional_ip_data"] = regional_ip_data
+
         # Prepare SDXL conditioning kwargs for the conditioned pass.
         added_cond_kwargs = None
         if conditioning_data.is_sdxl():
@@ -449,7 +465,7 @@ class InvokeAIDiffuserComponent:
             cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
                 regions=[conditioning_data.cond_regions], device=x.device, dtype=x.dtype
             )
-            cross_attention_kwargs["percent_through"] = percent_through
+            cross_attention_kwargs["percent_through"] = step_index / total_step_count

         # Run conditioned UNet denoising (i.e. positive prompt).
         conditioned_next_x = self.model_forward_callback(
diff --git a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
index 364ec18da4..416430b525 100644
--- a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
+++ b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
@@ -17,9 +17,6 @@ class UNetAttentionPatcher:
         if self._ip_adapters is not None:
             self._ip_adapter_scales = [1.0] * len(self._ip_adapters)

-    def set_scale(self, idx: int, value: float):
-        self._ip_adapter_scales[idx] = value
-
     def _prepare_attention_processors(self, unet: UNet2DConditionModel):
         """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
         weights into them (if IP-Adapters are being applied.
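
The sketch below is an illustration of the data flow this patch introduces; it is not part of the patch itself. Per-step IP-Adapter scales are computed once per step (mirroring IPAdapterData.scale_for_step() from the diff) and travel to every attention layer inside a single object passed through cross_attention_kwargs, instead of being pushed into each processor via the removed UNetAttentionPatcher.set_scale(). The standalone helper and the demo inputs (scale_for_step, adapters, dummy_embeds, total_steps) are illustrative names, not InvokeAI APIs.

    import math
    from typing import List, Union

    import torch


    def scale_for_step(weight: Union[float, List[float]], begin: float, end: float,
                       step_index: int, total_steps: int) -> float:
        # Mirrors IPAdapterData.scale_for_step() in the diff: the configured weight is
        # returned only while the current step falls inside the begin/end window,
        # otherwise 0.0 so the IP-Adapter has no effect at that step.
        first_step = math.floor(begin * total_steps)
        last_step = math.ceil(end * total_steps)
        w = weight[step_index] if isinstance(weight, list) else weight
        return w if first_step <= step_index <= last_step else 0.0


    # Two hypothetical IP-Adapters: a constant weight, and a per-step ramp active for
    # the second half of a 4-step schedule.
    adapters = [
        {"weight": 1.0, "begin": 0.0, "end": 1.0},
        {"weight": [0.0, 0.25, 0.5, 1.0], "begin": 0.5, "end": 1.0},
    ]
    total_steps = 4
    # Placeholder image-prompt embeddings: (batch_size, num_ip_images, seq_len, ip_embedding_len).
    dummy_embeds = [torch.zeros(2, 1, 4, 8) for _ in adapters]

    for step_index in range(total_steps):
        scales = [scale_for_step(a["weight"], a["begin"], a["end"], step_index, total_steps) for a in adapters]
        # In the patched pipeline these two lists are wrapped in a RegionalIPData object and
        # forwarded as cross_attention_kwargs["regional_ip_data"]; each CustomAttnProcessor2_0
        # then reads regional_ip_data.scales instead of its own _ip_adapter_scales list.
        cross_attention_kwargs = {"regional_ip_data": {"image_prompt_embeds": dummy_embeds, "scales": scales}}
        print(step_index, cross_attention_kwargs["regional_ip_data"]["scales"])

Because the scales are identical for all attention layers within a step, computing them once in the diffuser component and reading them from cross_attention_kwargs keeps the attention processors stateless with respect to scheduling, which is what makes the follow-up work on per-region IP-Adapter masks easier to layer on top.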