Pass IP-Adapter scales through the cross_attn_kwargs pathway, since they are the same for all attention layers. This change also prepares for adding IP-Adapter region masks.
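For orientation, a minimal sketch of the new data flow (not code from this commit; `unet`, `text_embeds`, `latent_model_input`, `timestep`, `image_prompt_embeds`, `step_index`, and `total_step_count` are assumed to exist in the caller):

    # Previously the pipeline mutated the patcher on every step:
    #     unet_attention_patcher.set_scale(i, weight)
    # Now the per-step scales travel with the image-prompt embeddings inside
    # cross_attention_kwargs, so the attention processors stay stateless.
    scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
    regional_ip_data = RegionalIPData(image_prompt_embeds=image_prompt_embeds, scales=scales)
    noise_pred = unet(
        latent_model_input,
        timestep,
        encoder_hidden_states=text_embeds,
        cross_attention_kwargs={"regional_ip_data": regional_ip_data},
    ).sample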

Ryan Dick 2024-03-14 13:56:03 -04:00 committed by Kent Keirsey
parent babdc64b17
commit 2e27ed5f3d
7 changed files with 95 additions and 74 deletions

View File

@@ -56,6 +56,7 @@ from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_sea
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
     IPAdapterConditioningInfo,
+    IPAdapterData,
     Range,
     SDXLConditioningInfo,
     TextConditioningData,
@@ -66,7 +67,6 @@ from invokeai.backend.util.silence_warnings import SilenceWarnings
 from ...backend.stable_diffusion.diffusers_pipeline import (
     ControlNetData,
-    IPAdapterData,
     StableDiffusionGeneratorPipeline,
     T2IAdapterData,
     image_resized_to_grid_as_tensor,

View File

@@ -21,9 +21,8 @@ from pydantic import Field
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from invokeai.app.services.config.config_default import get_config
-from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
-    IPAdapterConditioningInfo,
+    IPAdapterData,
     TextConditioningData,
 )
 from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
@@ -152,17 +151,6 @@ class ControlNetData:
     resize_mode: str = Field(default="just_resize")
-@dataclass
-class IPAdapterData:
-    ip_adapter_model: IPAdapter
-    ip_adapter_conditioning: IPAdapterConditioningInfo
-    # Either a single weight applied to all steps, or a list of weights for each step.
-    weight: Union[float, List[float]] = Field(default=1.0)
-    begin_step_percent: float = Field(default=0.0)
-    end_step_percent: float = Field(default=1.0)
 @dataclass
 class T2IAdapterData:
     """A structure containing the information required to apply conditioning from a single T2I-Adapter model."""
@@ -439,7 +427,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 control_data=control_data,
                 ip_adapter_data=ip_adapter_data,
                 t2i_adapter_data=t2i_adapter_data,
-                unet_attention_patcher=unet_attention_patcher,
             )
             latents = step_output.prev_sample
             predicted_original = getattr(step_output, "pred_original_sample", None)
@@ -471,7 +458,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         control_data: List[ControlNetData] = None,
         ip_adapter_data: Optional[list[IPAdapterData]] = None,
         t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
-        unet_attention_patcher: Optional[UNetAttentionPatcher] = None,
     ):
         # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
         timestep = t[0]
@@ -486,23 +472,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         # i.e. before or after passing it to InvokeAIDiffuserComponent
         latent_model_input = self.scheduler.scale_model_input(latents, timestep)
-        # handle IP-Adapter
-        if self.use_ip_adapter and ip_adapter_data is not None:  # somewhat redundant but logic is clearer
-            for i, single_ip_adapter_data in enumerate(ip_adapter_data):
-                first_adapter_step = math.floor(single_ip_adapter_data.begin_step_percent * total_step_count)
-                last_adapter_step = math.ceil(single_ip_adapter_data.end_step_percent * total_step_count)
-                weight = (
-                    single_ip_adapter_data.weight[step_index]
-                    if isinstance(single_ip_adapter_data.weight, List)
-                    else single_ip_adapter_data.weight
-                )
-                if step_index >= first_adapter_step and step_index <= last_adapter_step:
-                    # Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
-                    unet_attention_patcher.set_scale(i, weight)
-                else:
-                    # Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
-                    unet_attention_patcher.set_scale(i, 0.0)
         # Handle ControlNet(s)
         down_block_additional_residuals = None
         mid_block_additional_residual = None
@@ -545,17 +514,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 down_intrablock_additional_residuals = accum_adapter_state
-        ip_adapter_conditioning = None
-        if ip_adapter_data is not None:
-            ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]
         uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
             sample=latent_model_input,
             timestep=t,  # TODO: debug how handled batched and non batched timesteps
             step_index=step_index,
             total_step_count=total_step_count,
             conditioning_data=conditioning_data,
-            ip_adapter_conditioning=ip_adapter_conditioning,
+            ip_adapter_data=ip_adapter_data,
             down_block_additional_residuals=down_block_additional_residuals,  # for ControlNet
             mid_block_additional_residual=mid_block_additional_residual,  # for ControlNet
             down_intrablock_additional_residuals=down_intrablock_additional_residuals,  # for T2I-Adapter

View File

@@ -1,8 +1,11 @@
+import math
 from dataclasses import dataclass
 from typing import List, Optional, Union
 import torch
+from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
 @dataclass
 class BasicConditioningInfo:
@@ -45,6 +48,27 @@ class IPAdapterConditioningInfo:
     """
+@dataclass
+class IPAdapterData:
+    ip_adapter_model: IPAdapter
+    ip_adapter_conditioning: IPAdapterConditioningInfo
+    # Either a single weight applied to all steps, or a list of weights for each step.
+    weight: Union[float, List[float]] = 1.0
+    begin_step_percent: float = 0.0
+    end_step_percent: float = 1.0
+
+    def scale_for_step(self, step_index: int, total_steps: int) -> float:
+        first_adapter_step = math.floor(self.begin_step_percent * total_steps)
+        last_adapter_step = math.ceil(self.end_step_percent * total_steps)
+        weight = self.weight[step_index] if isinstance(self.weight, List) else self.weight
+        if step_index >= first_adapter_step and step_index <= last_adapter_step:
+            # Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
+            return weight
+        # Otherwise, return a scale of 0.0, so this IP-Adapter has no effect.
+        return 0.0
 @dataclass
 class Range:
     start: int
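As a quick check of the arithmetic in `scale_for_step` (hypothetical values, not from the commit): with 30 total steps, `begin_step_percent=0.2` and `end_step_percent=0.8` give `floor(0.2 * 30) = 6` and `ceil(0.8 * 30) = 24`, so the adapter contributes its weight on steps 6 through 24 inclusive and returns 0.0 elsewhere:

    # my_ip_adapter and my_conditioning are assumed instances of IPAdapter and
    # IPAdapterConditioningInfo; the numeric fields are invented for illustration.
    ipa = IPAdapterData(
        ip_adapter_model=my_ip_adapter,
        ip_adapter_conditioning=my_conditioning,
        weight=0.75,
        begin_step_percent=0.2,
        end_step_percent=0.8,
    )
    assert ipa.scale_for_step(step_index=5, total_steps=30) == 0.0    # before floor(0.2 * 30) = 6
    assert ipa.scale_for_step(step_index=6, total_steps=30) == 0.75   # first active step
    assert ipa.scale_for_step(step_index=24, total_steps=30) == 0.75  # last active step, ceil(0.8 * 30)
    assert ipa.scale_for_step(step_index=25, total_steps=30) == 0.0   # past the end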

View File

@@ -5,6 +5,7 @@ import torch.nn.functional as F
 from diffusers.models.attention_processor import Attention, AttnProcessor2_0
 from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
+from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData
 from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
@@ -54,15 +55,13 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         regional_prompt_data: Optional[RegionalPromptData] = None,
         percent_through: Optional[torch.FloatTensor] = None,
         # For IP-Adapter:
-        ip_adapter_image_prompt_embeds: Optional[list[torch.Tensor]] = None,
+        regional_ip_data: Optional[RegionalIPData] = None,
     ) -> torch.FloatTensor:
         """Apply attention.
         Args:
             regional_prompt_data: The regional prompt data for the current batch. If not None, this will be used to
                 apply regional prompt masking.
-            ip_adapter_image_prompt_embeds: The IP-Adapter image prompt embeddings for the current batch.
-                ip_adapter_image_prompt_embeds[i] contains the image prompt embeddings for the i'th IP-Adapter. Each
-                tensor has shape (batch_size, num_ip_images, seq_len, ip_embedding_len).
+            regional_ip_data: The IP-Adapter data for the current batch.
         """
         # If true, we are doing cross-attention, if false we are doing self-attention.
         is_cross_attention = encoder_hidden_states is not None
@@ -141,9 +140,9 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         # Apply IP-Adapter conditioning.
-        if is_cross_attention and self._is_ip_adapter_enabled():
+        if self._is_ip_adapter_enabled():
-            assert ip_adapter_image_prompt_embeds is not None
+            assert regional_ip_data is not None
             for ipa_embed, ipa_weights, scale in zip(
-                ip_adapter_image_prompt_embeds, self._ip_adapter_weights, self._ip_adapter_scales, strict=True
+                regional_ip_data.image_prompt_embeds, self._ip_adapter_weights, regional_ip_data.scales, strict=True
             ):
                 # The batch dimensions should match.
                 assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
@@ -178,8 +177,8 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
                 hidden_states = hidden_states + scale * ip_hidden_states
         else:
-            # If IP-Adapter is not enabled, then ip_adapter_image_prompt_embeds should not be passed in.
-            assert ip_adapter_image_prompt_embeds is None
+            # If IP-Adapter is not enabled, then regional_ip_data should not be passed in.
+            assert regional_ip_data is None
         # Start unmodified block from AttnProcessor2_0.
         # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
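This refactor leans on the fact that diffusers forwards the `cross_attention_kwargs` dict given to the UNet down into each attention processor's call. A toy model of that pathway (invented names, independent of the real diffusers internals):

    class ToyAttention:
        """Stand-in for a diffusers Attention block that delegates to a processor."""

        def __init__(self, processor):
            self.processor = processor

        def forward(self, hidden_states, encoder_hidden_states=None, **cross_attention_kwargs):
            # diffusers splats cross_attention_kwargs into the processor call.
            return self.processor(self, hidden_states, encoder_hidden_states, **cross_attention_kwargs)

    def toy_processor(attn, hidden_states, encoder_hidden_states=None, regional_ip_data=None):
        # Each layer receives the same regional_ip_data object the caller built
        # for this step, so no per-step mutation of processor state is needed.
        if regional_ip_data is not None:
            print("scales for this step:", regional_ip_data.scales)
        return hidden_states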

View File

@@ -0,0 +1,20 @@
+import torch
+
+
+class RegionalIPData:
+    """A class to manage the data for regional IP-Adapter conditioning."""
+
+    def __init__(
+        self,
+        image_prompt_embeds: list[torch.Tensor],
+        scales: list[float],
+    ):
+        """Initialize a `RegionalIPData` object."""
+        # The image prompt embeddings.
+        # image_prompt_embeds[i] contains the image prompt embeddings for the i'th IP-Adapter. Each tensor
+        # has shape (batch_size, num_ip_images, seq_len, ip_embedding_len).
+        self.image_prompt_embeds = image_prompt_embeds
+
+        # The scales for the IP-Adapter attention.
+        # scales[i] contains the attention scale for the i'th IP-Adapter.
+        self.scales = scales
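A small usage sketch with dummy shapes (sizes invented for illustration: two IP-Adapters, a CFG batch of 2, one image each, sequence length 4, embedding width 768):

    import torch

    embeds = [
        torch.zeros(2, 1, 4, 768),  # adapter 0: (batch_size, num_ip_images, seq_len, ip_embedding_len)
        torch.zeros(2, 1, 4, 768),  # adapter 1
    ]
    regional_ip_data = RegionalIPData(image_prompt_embeds=embeds, scales=[1.0, 0.5])
    # The attention processor zips these two lists, so adapter i is applied with scales[i]:
    for embed, scale in zip(regional_ip_data.image_prompt_embeds, regional_ip_data.scales):
        print(embed.shape, scale)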

View File

@@ -8,11 +8,12 @@ from typing_extensions import TypeAlias
 from invokeai.app.services.config.config_default import get_config
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
-    IPAdapterConditioningInfo,
+    IPAdapterData,
     Range,
     TextConditioningData,
     TextConditioningRegions,
 )
+from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData
 from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
 ModelForwardCallback: TypeAlias = Union[
@@ -169,15 +170,13 @@ class InvokeAIDiffuserComponent:
         sample: torch.Tensor,
         timestep: torch.Tensor,
         conditioning_data: TextConditioningData,
-        ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
+        ip_adapter_data: Optional[list[IPAdapterData]],
         step_index: int,
         total_step_count: int,
         down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
         mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
         down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
     ):
-        percent_through = step_index / total_step_count
         if self.sequential_guidance:
             (
                 unconditioned_next_x,
@@ -186,8 +185,9 @@ class InvokeAIDiffuserComponent:
                 x=sample,
                 sigma=timestep,
                 conditioning_data=conditioning_data,
-                ip_adapter_conditioning=ip_adapter_conditioning,
-                percent_through=percent_through,
+                ip_adapter_data=ip_adapter_data,
+                step_index=step_index,
+                total_step_count=total_step_count,
                 down_block_additional_residuals=down_block_additional_residuals,
                 mid_block_additional_residual=mid_block_additional_residual,
                 down_intrablock_additional_residuals=down_intrablock_additional_residuals,
@@ -200,8 +200,9 @@ class InvokeAIDiffuserComponent:
                 x=sample,
                 sigma=timestep,
                 conditioning_data=conditioning_data,
-                ip_adapter_conditioning=ip_adapter_conditioning,
-                percent_through=percent_through,
+                ip_adapter_data=ip_adapter_data,
+                step_index=step_index,
+                total_step_count=total_step_count,
                 down_block_additional_residuals=down_block_additional_residuals,
                 mid_block_additional_residual=mid_block_additional_residual,
                 down_intrablock_additional_residuals=down_intrablock_additional_residuals,
@@ -260,15 +261,16 @@ class InvokeAIDiffuserComponent:
     def _apply_standard_conditioning(
         self,
-        x,
-        sigma,
+        x: torch.Tensor,
+        sigma: torch.Tensor,
         conditioning_data: TextConditioningData,
-        ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
-        percent_through: float,
+        ip_adapter_data: Optional[list[IPAdapterData]],
+        step_index: int,
+        total_step_count: int,
         down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
         mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
         down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
-    ):
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
         the cost of higher memory usage.
         """
@@ -276,12 +278,16 @@ class InvokeAIDiffuserComponent:
         sigma_twice = torch.cat([sigma] * 2)
         cross_attention_kwargs = {}
-        if ip_adapter_conditioning is not None:
+        if ip_adapter_data is not None:
+            ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]
             # Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
-            cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
+            image_prompt_embeds = [
                 torch.stack([ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds])
                 for ipa_conditioning in ip_adapter_conditioning
             ]
+            scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
+            regional_ip_data = RegionalIPData(image_prompt_embeds=image_prompt_embeds, scales=scales)
+            cross_attention_kwargs["regional_ip_data"] = regional_ip_data
         added_cond_kwargs = None
         if conditioning_data.is_sdxl():
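The `torch.stack` above is what distinguishes the batched path from the sequential path later in this file: it fuses the unconditioned and conditioned image-prompt embeddings along a new leading batch dimension to match the doubled latents, while the sequential path `unsqueeze`s each into a batch of one. A shape-only sketch with dummy sizes:

    import torch

    uncond = torch.zeros(1, 4, 768)  # (num_ip_images, seq_len, ip_embedding_len)
    cond = torch.zeros(1, 4, 768)

    batched = torch.stack([uncond, cond])      # (2, 1, 4, 768): one UNet pass covers both
    sequential = torch.unsqueeze(cond, dim=0)  # (1, 1, 4, 768): one pass per conditioning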
@@ -326,7 +332,7 @@ class InvokeAIDiffuserComponent:
             cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
                 regions=regions, device=x.device, dtype=x.dtype
             )
-            cross_attention_kwargs["percent_through"] = percent_through
+            cross_attention_kwargs["percent_through"] = step_index / total_step_count
         both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
             conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds
@@ -350,8 +356,9 @@ class InvokeAIDiffuserComponent:
         x: torch.Tensor,
         sigma,
         conditioning_data: TextConditioningData,
-        ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
-        percent_through: float,
+        ip_adapter_data: Optional[list[IPAdapterData]],
+        step_index: int,
+        total_step_count: int,
         down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
         mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
         down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
@@ -388,13 +395,18 @@ class InvokeAIDiffuserComponent:
         cross_attention_kwargs = {}
         # Prepare IP-Adapter cross-attention kwargs for the unconditioned pass.
-        if ip_adapter_conditioning is not None:
+        if ip_adapter_data is not None:
+            ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]
             # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
-            cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
+            image_prompt_embeds = [
                 torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
                 for ipa_conditioning in ip_adapter_conditioning
             ]
+            scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
+            regional_ip_data = RegionalIPData(image_prompt_embeds=image_prompt_embeds, scales=scales)
+            cross_attention_kwargs["regional_ip_data"] = regional_ip_data
         # Prepare SDXL conditioning kwargs for the unconditioned pass.
         added_cond_kwargs = None
         if conditioning_data.is_sdxl():
@@ -408,7 +420,7 @@ class InvokeAIDiffuserComponent:
             cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
                 regions=[conditioning_data.uncond_regions], device=x.device, dtype=x.dtype
             )
-            cross_attention_kwargs["percent_through"] = percent_through
+            cross_attention_kwargs["percent_through"] = step_index / total_step_count
         # Run unconditioned UNet denoising (i.e. negative prompt).
         unconditioned_next_x = self.model_forward_callback(
@@ -428,14 +440,18 @@ class InvokeAIDiffuserComponent:
         cross_attention_kwargs = {}
         # Prepare IP-Adapter cross-attention kwargs for the conditioned pass.
-        if ip_adapter_conditioning is not None:
+        if ip_adapter_data is not None:
+            ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]
             # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
-            cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
+            image_prompt_embeds = [
                 torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
                 for ipa_conditioning in ip_adapter_conditioning
             ]
+            scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
+            regional_ip_data = RegionalIPData(image_prompt_embeds=image_prompt_embeds, scales=scales)
+            cross_attention_kwargs["regional_ip_data"] = regional_ip_data
         # Prepare SDXL conditioning kwargs for the conditioned pass.
         added_cond_kwargs = None
         if conditioning_data.is_sdxl():
@@ -449,7 +465,7 @@ class InvokeAIDiffuserComponent:
             cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
                 regions=[conditioning_data.cond_regions], device=x.device, dtype=x.dtype
            )
-            cross_attention_kwargs["percent_through"] = percent_through
+            cross_attention_kwargs["percent_through"] = step_index / total_step_count
         # Run conditioned UNet denoising (i.e. positive prompt).
         conditioned_next_x = self.model_forward_callback(

View File

@@ -17,9 +17,6 @@ class UNetAttentionPatcher:
         if self._ip_adapters is not None:
             self._ip_adapter_scales = [1.0] * len(self._ip_adapters)
-    def set_scale(self, idx: int, value: float):
-        self._ip_adapter_scales[idx] = value
     def _prepare_attention_processors(self, unet: UNet2DConditionModel):
         """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
         weights into them (if IP-Adapters are being applied).
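For context on where these processors are installed: in InvokeAI the patcher is applied around the denoising loop, and with this commit it no longer carries mutable per-step scale state. A hedged sketch (it assumes the patcher's `apply_ip_adapter_attention` context manager and the surrounding `unet` and `ip_adapters` objects, which are not shown in this diff):

    patcher = UNetAttentionPatcher(ip_adapters)  # list of IPAdapter models, or None
    with patcher.apply_ip_adapter_attention(unet):
        # Every attention layer now runs CustomAttnProcessor2_0; per-step scales
        # arrive via cross_attention_kwargs instead of the removed set_scale().
        ...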