Merge branch 'main' into pr/6086

blessedcoolant
2024-05-01 00:37:06 +05:30
338 changed files with 11169 additions and 3081 deletions

View File: invokeai/backend/stable_diffusion/diffusers_pipeline.py

@ -21,12 +21,11 @@ from pydantic import Field
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
from invokeai.backend.util.attention import auto_detect_slice_size
from invokeai.backend.util.devices import normalize_device
from invokeai.backend.util.devices import TorchDevice
@dataclass
@ -149,16 +148,6 @@ class ControlNetData:
resize_mode: str = Field(default="just_resize")
@dataclass
class IPAdapterData:
ip_adapter_model: IPAdapter = Field(default=None)
# TODO: change to polymorphic so can do different weights per step (once implemented...)
weight: Union[float, List[float]] = Field(default=1.0)
# weight: float = Field(default=1.0)
begin_step_percent: float = Field(default=0.0)
end_step_percent: float = Field(default=1.0)
@dataclass
class T2IAdapterData:
"""A structure containing the information required to apply conditioning from a single T2I-Adapter model."""
@ -266,7 +255,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if self.unet.device.type == "cpu" or self.unet.device.type == "mps":
mem_free = psutil.virtual_memory().free
elif self.unet.device.type == "cuda":
mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.unet.device))
mem_free, _ = torch.cuda.mem_get_info(TorchDevice.normalize(self.unet.device))
else:
raise ValueError(f"unrecognized device {self.unet.device}")
# input tensor of [1, 4, h/8, w/8]
@ -295,7 +284,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self,
latents: torch.Tensor,
num_inference_steps: int,
conditioning_data: ConditioningData,
scheduler_step_kwargs: dict[str, Any],
conditioning_data: TextConditioningData,
*,
noise: Optional[torch.Tensor],
timesteps: torch.Tensor,
@ -308,7 +298,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
mask: Optional[torch.Tensor] = None,
masked_latents: Optional[torch.Tensor] = None,
gradient_mask: Optional[bool] = False,
seed: Optional[int] = None,
seed: int,
) -> torch.Tensor:
if init_timestep.shape[0] == 0:
return latents
@ -326,20 +316,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
latents = self.scheduler.add_noise(latents, noise, batched_t)
if mask is not None:
# if no noise provided, noisify unmasked area based on seed (or 0 as fallback)
if noise is None:
noise = torch.randn(
orig_latents.shape,
dtype=torch.float32,
device="cpu",
generator=torch.Generator(device="cpu").manual_seed(seed or 0),
).to(device=orig_latents.device, dtype=orig_latents.dtype)
latents = self.scheduler.add_noise(latents, noise, batched_t)
latents = torch.lerp(
orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
)
if is_inpainting_model(self.unet):
if masked_latents is None:
raise Exception("Source image required for inpaint mask when inpaint model used!")
@ -348,6 +324,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self._unet_forward, mask, masked_latents
)
else:
# if no noise provided, noisify unmasked area based on seed
if noise is None:
noise = torch.randn(
orig_latents.shape,
dtype=torch.float32,
device="cpu",
generator=torch.Generator(device="cpu").manual_seed(seed),
).to(device=orig_latents.device, dtype=orig_latents.dtype)
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask))
try:
@ -355,6 +340,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
latents,
timesteps,
conditioning_data,
scheduler_step_kwargs=scheduler_step_kwargs,
additional_guidance=additional_guidance,
control_data=control_data,
ip_adapter_data=ip_adapter_data,
@ -380,7 +366,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self,
latents: torch.Tensor,
timesteps,
conditioning_data: ConditioningData,
conditioning_data: TextConditioningData,
scheduler_step_kwargs: dict[str, Any],
*,
additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
@ -397,22 +384,22 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if timesteps.shape[0] == 0:
return latents
ip_adapter_unet_patcher = None
extra_conditioning_info = conditioning_data.text_embeddings.extra_conditioning
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
attn_ctx = self.invokeai_diffuser.custom_attention_context(
self.invokeai_diffuser.model,
extra_conditioning_info=extra_conditioning_info,
use_ip_adapter = ip_adapter_data is not None
use_regional_prompting = (
conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
)
unet_attention_patcher = None
self.use_ip_adapter = use_ip_adapter
attn_ctx = nullcontext()
if use_ip_adapter or use_regional_prompting:
ip_adapters: Optional[List[UNetIPAdapterData]] = (
[{"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data]
if use_ip_adapter
else None
)
self.use_ip_adapter = False
elif ip_adapter_data is not None:
# TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
# As it is now, the IP-Adapter will silently be skipped.
ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
self.use_ip_adapter = True
else:
attn_ctx = nullcontext()
unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
with attn_ctx:
if callback is not None:
@ -435,11 +422,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
conditioning_data,
step_index=i,
total_step_count=len(timesteps),
scheduler_step_kwargs=scheduler_step_kwargs,
additional_guidance=additional_guidance,
control_data=control_data,
ip_adapter_data=ip_adapter_data,
t2i_adapter_data=t2i_adapter_data,
ip_adapter_unet_patcher=ip_adapter_unet_patcher,
)
latents = step_output.prev_sample
predicted_original = getattr(step_output, "pred_original_sample", None)
@ -463,14 +450,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self,
t: torch.Tensor,
latents: torch.Tensor,
conditioning_data: ConditioningData,
conditioning_data: TextConditioningData,
step_index: int,
total_step_count: int,
scheduler_step_kwargs: dict[str, Any],
additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
ip_adapter_data: Optional[list[IPAdapterData]] = None,
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
ip_adapter_unet_patcher: Optional[UNetPatcher] = None,
):
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
timestep = t[0]
@ -485,23 +472,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# i.e. before or after passing it to InvokeAIDiffuserComponent
latent_model_input = self.scheduler.scale_model_input(latents, timestep)
# handle IP-Adapter
if self.use_ip_adapter and ip_adapter_data is not None: # somewhat redundant but logic is clearer
for i, single_ip_adapter_data in enumerate(ip_adapter_data):
first_adapter_step = math.floor(single_ip_adapter_data.begin_step_percent * total_step_count)
last_adapter_step = math.ceil(single_ip_adapter_data.end_step_percent * total_step_count)
weight = (
single_ip_adapter_data.weight[step_index]
if isinstance(single_ip_adapter_data.weight, List)
else single_ip_adapter_data.weight
)
if step_index >= first_adapter_step and step_index <= last_adapter_step:
# Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
ip_adapter_unet_patcher.set_scale(i, weight)
else:
# Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
ip_adapter_unet_patcher.set_scale(i, 0.0)
# Handle ControlNet(s)
down_block_additional_residuals = None
mid_block_additional_residual = None
@ -550,6 +520,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
step_index=step_index,
total_step_count=total_step_count,
conditioning_data=conditioning_data,
ip_adapter_data=ip_adapter_data,
down_block_additional_residuals=down_block_additional_residuals, # for ControlNet
mid_block_additional_residual=mid_block_additional_residual, # for ControlNet
down_intrablock_additional_residuals=down_intrablock_additional_residuals, # for T2I-Adapter
@ -569,7 +540,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
)
# compute the previous noisy sample x_t -> x_t-1
step_output = self.scheduler.step(noise_pred, timestep, latents, **conditioning_data.scheduler_args)
step_output = self.scheduler.step(noise_pred, timestep, latents, **scheduler_step_kwargs)
# TODO: discuss injection point options. For now this is a patch to get progress images working with inpainting again.
for guidance in additional_guidance:
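
Editor's note: the pipeline now takes `scheduler_step_kwargs` from the caller instead of carrying scheduler arguments on `ConditioningData`. Below is a minimal sketch of how a caller might assemble that dict, mirroring the removed `add_scheduler_args_if_applicable` logic; the helper name and the example kwargs are assumptions for illustration, not part of this diff.

import inspect
from typing import Any

def build_scheduler_step_kwargs(scheduler, **kwargs: Any) -> dict[str, Any]:
    # Keep only the kwargs that this scheduler's step() signature actually accepts.
    step_params = inspect.signature(scheduler.step).parameters
    return {name: value for name, value in kwargs.items() if name in step_params}

# e.g. scheduler_step_kwargs = build_scheduler_step_kwargs(scheduler, eta=0.0, generator=generator)
# The dict is forwarded unchanged to `self.scheduler.step(noise_pred, timestep, latents, **scheduler_step_kwargs)`.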

View File: invokeai/backend/stable_diffusion/diffusion/conditioning_data.py

@ -1,27 +1,17 @@
import dataclasses
import inspect
from dataclasses import dataclass, field
from typing import Any, List, Optional, Union
import math
from dataclasses import dataclass
from typing import List, Optional, Union
import torch
from .cross_attention_control import Arguments
@dataclass
class ExtraConditioningInfo:
tokens_count_including_eos_bos: int
cross_attention_control_args: Optional[Arguments] = None
@property
def wants_cross_attention_control(self):
return self.cross_attention_control_args is not None
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
@dataclass
class BasicConditioningInfo:
"""SD 1/2 text conditioning information produced by Compel."""
embeds: torch.Tensor
extra_conditioning: Optional[ExtraConditioningInfo]
def to(self, device, dtype=None):
self.embeds = self.embeds.to(device=device, dtype=dtype)
@ -35,6 +25,8 @@ class ConditioningFieldData:
@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
"""SDXL text conditioning information produced by Compel."""
pooled_embeds: torch.Tensor
add_time_ids: torch.Tensor
@ -57,37 +49,75 @@ class IPAdapterConditioningInfo:
@dataclass
class ConditioningData:
unconditioned_embeddings: BasicConditioningInfo
text_embeddings: BasicConditioningInfo
"""
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
Guidance scale is enabled by setting `guidance_scale > 1`. A higher guidance scale encourages the model to
generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
"""
guidance_scale: Union[float, List[float]]
""" for models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7 .
ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf)
"""
guidance_rescale_multiplier: float = 0
scheduler_args: dict[str, Any] = field(default_factory=dict)
class IPAdapterData:
ip_adapter_model: IPAdapter
ip_adapter_conditioning: IPAdapterConditioningInfo
mask: torch.Tensor
target_blocks: List[str]
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]] = None
# Either a single weight applied to all steps, or a list of weights for each step.
weight: Union[float, List[float]] = 1.0
begin_step_percent: float = 0.0
end_step_percent: float = 1.0
@property
def dtype(self):
return self.text_embeddings.dtype
def scale_for_step(self, step_index: int, total_steps: int) -> float:
first_adapter_step = math.floor(self.begin_step_percent * total_steps)
last_adapter_step = math.ceil(self.end_step_percent * total_steps)
weight = self.weight[step_index] if isinstance(self.weight, List) else self.weight
if step_index >= first_adapter_step and step_index <= last_adapter_step:
# Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
return weight
# Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
return 0.0
def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
scheduler_args = dict(self.scheduler_args)
step_method = inspect.signature(scheduler.step)
for name, value in kwargs.items():
try:
step_method.bind_partial(**{name: value})
except TypeError:
# FIXME: don't silently discard arguments
pass # debug("%s does not accept argument named %r", scheduler, name)
else:
scheduler_args[name] = value
return dataclasses.replace(self, scheduler_args=scheduler_args)
@dataclass
class Range:
start: int
end: int
class TextConditioningRegions:
def __init__(
self,
masks: torch.Tensor,
ranges: list[Range],
):
# A binary mask indicating the regions of the image that the prompt should be applied to.
# Shape: (1, num_prompts, height, width)
# Dtype: torch.bool
self.masks = masks
# A list of ranges indicating the start and end indices of the embeddings that the corresponding mask applies to.
# ranges[i] contains the embedding range for the i'th prompt / mask.
self.ranges = ranges
assert self.masks.shape[1] == len(self.ranges)
class TextConditioningData:
def __init__(
self,
uncond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
cond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
uncond_regions: Optional[TextConditioningRegions],
cond_regions: Optional[TextConditioningRegions],
guidance_scale: Union[float, List[float]],
guidance_rescale_multiplier: float = 0,
):
self.uncond_text = uncond_text
self.cond_text = cond_text
self.uncond_regions = uncond_regions
self.cond_regions = cond_regions
# Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
# `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
# Guidance scale is enabled by setting `guidance_scale > 1`. A higher guidance scale encourages the model to
# generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
self.guidance_scale = guidance_scale
# For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7.
# See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
self.guidance_rescale_multiplier = guidance_rescale_multiplier
def is_sdxl(self):
assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
return isinstance(self.cond_text, SDXLConditioningInfo)
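
Editor's note: for reference, a worked example of the step-range logic that `IPAdapterData.scale_for_step` implements; the weight schedule and step count below are made-up illustration values.

import math

begin_step_percent, end_step_percent = 0.0, 0.5
weight = [0.8, 0.8, 0.6, 0.4, 0.2]  # one weight per step
total_steps = 5

first_adapter_step = math.floor(begin_step_percent * total_steps)  # 0
last_adapter_step = math.ceil(end_step_percent * total_steps)      # 3

for step_index in range(total_steps):
    w = weight[step_index] if isinstance(weight, list) else weight
    scale = w if first_adapter_step <= step_index <= last_adapter_step else 0.0
    # Steps 0-3 use their per-step weight (0.8, 0.8, 0.6, 0.4); step 4 falls outside the range, so the scale is 0.0.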

View File: invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py (deleted)

@ -1,218 +0,0 @@
# adapted from bloc97's CrossAttentionControl colab
# https://github.com/bloc97/CrossAttentionControl
import enum
from dataclasses import dataclass, field
from typing import Optional
import torch
from compel.cross_attention_control import Arguments
from diffusers.models.attention_processor import Attention, SlicedAttnProcessor
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from invokeai.backend.util.devices import torch_dtype
class CrossAttentionType(enum.Enum):
SELF = 1
TOKENS = 2
class CrossAttnControlContext:
def __init__(self, arguments: Arguments):
"""
:param arguments: Arguments for the cross-attention control process
"""
self.cross_attention_mask: Optional[torch.Tensor] = None
self.cross_attention_index_map: Optional[torch.Tensor] = None
self.arguments = arguments
def get_active_cross_attention_control_types_for_step(
self, percent_through: float = None
) -> list[CrossAttentionType]:
"""
Should cross-attention control be applied on the given step?
:param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
:return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
"""
if percent_through is None:
return [CrossAttentionType.SELF, CrossAttentionType.TOKENS]
opts = self.arguments.edit_options
to_control = []
if opts["s_start"] <= percent_through < opts["s_end"]:
to_control.append(CrossAttentionType.SELF)
if opts["t_start"] <= percent_through < opts["t_end"]:
to_control.append(CrossAttentionType.TOKENS)
return to_control
def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: CrossAttnControlContext):
"""
Inject attention parameters and functions into the passed in model to enable cross attention editing.
:param model: The unet model to inject into.
:return: None
"""
# adapted from init_attention_edit
device = context.arguments.edited_conditioning.device
# urgh. should this be hardcoded?
max_length = 77
# mask=1 means use base prompt attention, mask=0 means use edited prompt attention
mask = torch.zeros(max_length, dtype=torch_dtype(device))
indices_target = torch.arange(max_length, dtype=torch.long)
indices = torch.arange(max_length, dtype=torch.long)
for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
if b0 < max_length:
if name == "equal": # or (name == "replace" and a1 - a0 == b1 - b0):
# these tokens have not been edited
indices[b0:b1] = indices_target[a0:a1]
mask[b0:b1] = 1
context.cross_attention_mask = mask.to(device)
context.cross_attention_index_map = indices.to(device)
old_attn_processors = unet.attn_processors
if torch.backends.mps.is_available():
# see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
unet.set_attn_processor(SwapCrossAttnProcessor())
else:
# try to re-use an existing slice size
default_slice_size = 4
slice_size = next(
(p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size
)
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
@dataclass
class SwapCrossAttnContext:
modified_text_embeddings: torch.Tensor
index_map: torch.Tensor # maps from original prompt token indices to the equivalent tokens in the modified prompt
mask: torch.Tensor # in the target space of the index_map
cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=list)
def wants_cross_attention_control(self, attn_type: CrossAttentionType) -> bool:
return attn_type in self.cross_attention_types_to_do
@classmethod
def make_mask_and_index_map(
cls, edit_opcodes: list[tuple[str, int, int, int, int]], max_length: int
) -> tuple[torch.Tensor, torch.Tensor]:
# mask=1 means use original prompt attention, mask=0 means use modified prompt attention
mask = torch.zeros(max_length)
indices_target = torch.arange(max_length, dtype=torch.long)
indices = torch.arange(max_length, dtype=torch.long)
for name, a0, a1, b0, b1 in edit_opcodes:
if b0 < max_length:
if name == "equal":
# these tokens remain the same as in the original prompt
indices[b0:b1] = indices_target[a0:a1]
mask[b0:b1] = 1
return mask, indices
class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
# TODO: dynamically pick slice size based on memory conditions
def __call__(
self,
attn: Attention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
# kwargs
swap_cross_attn_context: SwapCrossAttnContext = None,
**kwargs,
):
attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS
# if cross-attention control is not in play, just call through to the base implementation.
if (
attention_type is CrossAttentionType.SELF
or swap_cross_attn_context is None
or not swap_cross_attn_context.wants_cross_attention_control(attention_type)
):
# print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass")
return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)
# else:
# print(f"SwapCrossAttnContext for {attention_type} active")
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(
attention_mask=attention_mask,
target_length=sequence_length,
batch_size=batch_size,
)
query = attn.to_q(hidden_states)
dim = query.shape[-1]
query = attn.head_to_batch_dim(query)
original_text_embeddings = encoder_hidden_states
modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings
original_text_key = attn.to_k(original_text_embeddings)
modified_text_key = attn.to_k(modified_text_embeddings)
original_value = attn.to_v(original_text_embeddings)
modified_value = attn.to_v(modified_text_embeddings)
original_text_key = attn.head_to_batch_dim(original_text_key)
modified_text_key = attn.head_to_batch_dim(modified_text_key)
original_value = attn.head_to_batch_dim(original_value)
modified_value = attn.head_to_batch_dim(modified_value)
# compute slices and prepare output tensor
batch_size_attention = query.shape[0]
hidden_states = torch.zeros(
(batch_size_attention, sequence_length, dim // attn.heads),
device=query.device,
dtype=query.dtype,
)
# do slices
for i in range(max(1, hidden_states.shape[0] // self.slice_size)):
start_idx = i * self.slice_size
end_idx = (i + 1) * self.slice_size
query_slice = query[start_idx:end_idx]
original_key_slice = original_text_key[start_idx:end_idx]
modified_key_slice = modified_text_key[start_idx:end_idx]
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
original_attn_slice = attn.get_attention_scores(query_slice, original_key_slice, attn_mask_slice)
modified_attn_slice = attn.get_attention_scores(query_slice, modified_key_slice, attn_mask_slice)
# because the prompt modifications may result in token sequences shifted forwards or backwards,
# the original attention probabilities must be remapped to account for token index changes in the
# modified prompt
remapped_original_attn_slice = torch.index_select(
original_attn_slice, -1, swap_cross_attn_context.index_map
)
# only some tokens taken from the original attention probabilities. this is controlled by the mask.
mask = swap_cross_attn_context.mask
inverse_mask = 1 - mask
attn_slice = remapped_original_attn_slice * mask + modified_attn_slice * inverse_mask
del remapped_original_attn_slice, modified_attn_slice
attn_slice = torch.bmm(attn_slice, modified_value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
# done
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class SwapCrossAttnProcessor(SlicedSwapCrossAttnProcesser):
def __init__(self):
super(SwapCrossAttnProcessor, self).__init__(slice_size=int(1e9)) # massive slice size = don't slice

View File: invokeai/backend/stable_diffusion/diffusion/custom_atttention.py (new)

@ -0,0 +1,214 @@
from dataclasses import dataclass
from typing import List, Optional, cast
import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import Attention, AttnProcessor2_0
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
@dataclass
class IPAdapterAttentionWeights:
ip_adapter_weights: IPAttentionProcessorWeights
skip: bool
class CustomAttnProcessor2_0(AttnProcessor2_0):
"""A custom implementation of AttnProcessor2_0 that supports additional Invoke features.
This implementation is based on
https://github.com/huggingface/diffusers/blame/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1204
Supported custom features:
- IP-Adapter
- Regional prompt attention
"""
def __init__(
self,
ip_adapter_attention_weights: Optional[List[IPAdapterAttentionWeights]] = None,
):
"""Initialize a CustomAttnProcessor2_0.
Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are
layer-specific are passed to __init__().
Args:
ip_adapter_attention_weights: The IP-Adapter attention weights. ip_adapter_attention_weights[i] contains
the attention weights for the i'th IP-Adapter.
"""
super().__init__()
self._ip_adapter_attention_weights = ip_adapter_attention_weights
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
# For Regional Prompting:
regional_prompt_data: Optional[RegionalPromptData] = None,
percent_through: Optional[torch.Tensor] = None,
# For IP-Adapter:
regional_ip_data: Optional[RegionalIPData] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
"""Apply attention.
Args:
regional_prompt_data: The regional prompt data for the current batch. If not None, this will be used to
apply regional prompt masking.
regional_ip_data: The IP-Adapter data for the current batch.
"""
# If true, we are doing cross-attention, if false we are doing self-attention.
is_cross_attention = encoder_hidden_states is not None
# Start unmodified block from AttnProcessor2_0.
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# End unmodified block from AttnProcessor2_0.
_, query_seq_len, _ = hidden_states.shape
# Handle regional prompt attention masks.
if regional_prompt_data is not None and is_cross_attention:
assert percent_through is not None
prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
query_seq_len=query_seq_len, key_seq_len=sequence_length
)
if attention_mask is None:
attention_mask = prompt_region_attention_mask
else:
attention_mask = prompt_region_attention_mask + attention_mask
# Start unmodified block from AttnProcessor2_0.
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# End unmodified block from AttnProcessor2_0.
# Apply IP-Adapter conditioning.
if is_cross_attention:
if self._ip_adapter_attention_weights:
assert regional_ip_data is not None
ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)
assert (
len(regional_ip_data.image_prompt_embeds)
== len(self._ip_adapter_attention_weights)
== len(regional_ip_data.scales)
== ip_masks.shape[1]
)
for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds):
ipa_weights = self._ip_adapter_attention_weights[ipa_index].ip_adapter_weights
ipa_scale = regional_ip_data.scales[ipa_index]
ip_mask = ip_masks[0, ipa_index, ...]
# The batch dimensions should match.
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
# The token_len dimensions should match.
assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1]
ip_hidden_states = ipa_embed
# Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
if not self._ip_adapter_attention_weights[ipa_index].skip:
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
# Expected ip_key and ip_value shape:
# (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# Expected ip_key and ip_value shape:
# (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
ip_hidden_states = ip_hidden_states.to(query.dtype)
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
else:
# If IP-Adapter is not enabled, then regional_ip_data should not be passed in.
assert regional_ip_data is None
# Start unmodified block from AttnProcessor2_0.
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# End of unmodified block from AttnProcessor2_0
# casting torch.Tensor to torch.FloatTensor to avoid type issues
return cast(torch.FloatTensor, hidden_states)
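
Editor's note: because `CustomAttnProcessor2_0` only activates its extra behaviour when `regional_prompt_data` / `regional_ip_data` are passed in, it can stand in for the stock processor. A minimal sketch, assuming the diffusers `Attention` module and the module path used elsewhere in this diff.

import torch
from diffusers.models.attention_processor import Attention

from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0

# A self-attention layer with the custom processor and no IP-Adapter weights behaves
# exactly like the stock AttnProcessor2_0.
attn = Attention(query_dim=64, heads=4, dim_head=16, processor=CustomAttnProcessor2_0())
hidden_states = torch.randn(1, 32, 64)  # (batch, seq_len, channels)
out = attn(hidden_states)               # same shape as the input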

View File: invokeai/backend/stable_diffusion/diffusion/regional_ip_data.py (new)

@ -0,0 +1,72 @@
import torch
class RegionalIPData:
"""A class to manage the data for regional IP-Adapter conditioning."""
def __init__(
self,
image_prompt_embeds: list[torch.Tensor],
scales: list[float],
masks: list[torch.Tensor],
dtype: torch.dtype,
device: torch.device,
max_downscale_factor: int = 8,
):
"""Initialize a `IPAdapterConditioningData` object."""
assert len(image_prompt_embeds) == len(scales) == len(masks)
# The image prompt embeddings.
# image_prompt_embeds[i] contains the image prompt embeddings for the i'th IP-Adapter. Each tensor
# has shape (batch_size, num_ip_images, seq_len, ip_embedding_len).
self.image_prompt_embeds = image_prompt_embeds
# The scales for the IP-Adapter attention.
# scales[i] contains the attention scale for the i'th IP-Adapter.
self.scales = scales
# The IP-Adapter masks.
# self._masks_by_seq_len[s] contains the spatial masks for the downsampling level with query sequence length of
# s. It has shape (batch_size, num_ip_adapters, query_seq_len, 1). The masks have values of 1.0 for included
# regions and 0.0 for excluded regions.
self._masks_by_seq_len = self._prepare_masks(masks, max_downscale_factor, device, dtype)
def _prepare_masks(
self, masks: list[torch.Tensor], max_downscale_factor: int, device: torch.device, dtype: torch.dtype
) -> dict[int, torch.Tensor]:
"""Prepare the masks for the IP-Adapter attention."""
# Concatenate the masks so that they can be processed more efficiently.
mask_tensor = torch.cat(masks, dim=1)
mask_tensor = mask_tensor.to(device=device, dtype=dtype)
masks_by_seq_len: dict[int, torch.Tensor] = {}
# Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
downscale_factor = 1
while downscale_factor <= max_downscale_factor:
b, num_ip_adapters, h, w = mask_tensor.shape
# Assert that the batch size is 1, because I haven't thought through batch handling for this feature yet.
assert b == 1
# The IP-Adapters are applied in the cross-attention layers, where the query sequence length is the h * w of
# the spatial features.
query_seq_len = h * w
masks_by_seq_len[query_seq_len] = mask_tensor.view((b, num_ip_adapters, -1, 1))
downscale_factor *= 2
if downscale_factor <= max_downscale_factor:
# We use max pooling because we downscale to a pretty low resolution, so we don't want small mask
# regions to be lost entirely.
#
# ceil_mode=True is set to mirror the downsampling behavior of SD and SDXL.
#
# TODO(ryand): In the future, we may want to experiment with other downsampling methods.
mask_tensor = torch.nn.functional.max_pool2d(mask_tensor, kernel_size=2, stride=2, ceil_mode=True)
return masks_by_seq_len
def get_masks(self, query_seq_len: int) -> torch.Tensor:
"""Get the mask for the given query sequence length."""
return self._masks_by_seq_len[query_seq_len]
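
Editor's note: a short sketch of what `_prepare_masks` produces for a typical SD latent size; the tensor sizes below are illustrative assumptions, not values from this diff.

import torch

from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData

# One IP-Adapter, a full-image mask, and a 64x64 latent (i.e. a 512x512 image).
embeds = [torch.zeros(2, 1, 4, 768)]  # (batch, num_ip_images, seq_len, ip_embedding_len)
masks = [torch.ones(1, 1, 64, 64)]    # (batch, num_ip_adapters=1, latent_h, latent_w)
data = RegionalIPData(image_prompt_embeds=embeds, scales=[1.0], masks=masks,
                      dtype=torch.float32, device=torch.device("cpu"))

# Masks are max-pooled by 2x down to max_downscale_factor=8 and keyed by query_seq_len = h * w:
# 64*64=4096, 32*32=1024, 16*16=256, 8*8=64.
print(data.get_masks(query_seq_len=1024).shape)  # torch.Size([1, 1, 1024, 1])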

View File: invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py (new)

@ -0,0 +1,105 @@
import torch
import torch.nn.functional as F
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
TextConditioningRegions,
)
class RegionalPromptData:
"""A class to manage the prompt data for regional conditioning."""
def __init__(
self,
regions: list[TextConditioningRegions],
device: torch.device,
dtype: torch.dtype,
max_downscale_factor: int = 8,
):
"""Initialize a `RegionalPromptData` object.
Args:
regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the
batch.
device (torch.device): The device to use for the attention masks.
dtype (torch.dtype): The data type to use for the attention masks.
max_downscale_factor: Spatial masks will be prepared for downscale factors from 1 to max_downscale_factor
in steps of 2x.
"""
self._regions = regions
self._device = device
self._dtype = dtype
# self._spatial_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query
# sequence length of s.
self._spatial_masks_by_seq_len: list[dict[int, torch.Tensor]] = self._prepare_spatial_masks(
regions, max_downscale_factor
)
self._negative_cross_attn_mask_score = -10000.0
def _prepare_spatial_masks(
self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8
) -> list[dict[int, torch.Tensor]]:
"""Prepare the spatial masks for all downscaling factors."""
# batch_sample_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query sequence length
# of s.
batch_sample_masks_by_seq_len: list[dict[int, torch.Tensor]] = []
for batch_sample_regions in regions:
batch_sample_masks_by_seq_len.append({})
batch_sample_masks = batch_sample_regions.masks.to(device=self._device, dtype=self._dtype)
# Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
downscale_factor = 1
while downscale_factor <= max_downscale_factor:
b, _num_prompts, h, w = batch_sample_masks.shape
assert b == 1
query_seq_len = h * w
batch_sample_masks_by_seq_len[-1][query_seq_len] = batch_sample_masks
downscale_factor *= 2
if downscale_factor <= max_downscale_factor:
# We use max pooling because we downscale to a pretty low resolution, so we don't want small prompt
# regions to be lost entirely.
#
# ceil_mode=True is set to mirror the downsampling behavior of SD and SDXL.
#
# TODO(ryand): In the future, we may want to experiment with other downsampling methods (e.g.
# nearest interpolation), and could potentially use a weighted mask rather than a binary mask.
batch_sample_masks = F.max_pool2d(batch_sample_masks, kernel_size=2, stride=2, ceil_mode=True)
return batch_sample_masks_by_seq_len
def get_cross_attn_mask(self, query_seq_len: int, key_seq_len: int) -> torch.Tensor:
"""Get the cross-attention mask for the given query sequence length.
Args:
query_seq_len: The length of the flattened spatial features at the current downscaling level.
key_seq_len (int): The sequence length of the prompt embeddings (which act as the key in the cross-attention
layers). This is most likely equal to the max embedding range end, but we pass it explicitly to be sure.
Returns:
torch.Tensor: The cross-attention score mask.
shape: (batch_size, query_seq_len, key_seq_len).
dtype: float
"""
batch_size = len(self._spatial_masks_by_seq_len)
batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)]
# Create an empty attention mask with the correct shape.
attn_mask = torch.zeros((batch_size, query_seq_len, key_seq_len), dtype=self._dtype, device=self._device)
for batch_idx in range(batch_size):
batch_sample_spatial_masks = batch_spatial_masks[batch_idx]
batch_sample_regions = self._regions[batch_idx]
# Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1).
_, num_prompts, _, _ = batch_sample_spatial_masks.shape
batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1))
for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
batch_sample_query_scores = batch_sample_query_masks[0, prompt_idx, :, :].clone()
batch_sample_query_mask = batch_sample_query_scores > 0.5
batch_sample_query_scores[batch_sample_query_mask] = 0.0
batch_sample_query_scores[~batch_sample_query_mask] = self._negative_cross_attn_mask_score
attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_scores
return attn_mask
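
Editor's note: a small sketch of the cross-attention mask produced for a two-prompt regional setup; the latent size and the 77-token embedding ranges are illustrative assumptions.

import torch

from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range, TextConditioningRegions
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData

# Two prompts: the first covers the left half of a 64x64 latent, the second the right half.
masks = torch.zeros(1, 2, 64, 64, dtype=torch.bool)
masks[0, 0, :, :32] = True
masks[0, 1, :, 32:] = True
regions = TextConditioningRegions(masks=masks, ranges=[Range(start=0, end=77), Range(start=77, end=154)])

rpd = RegionalPromptData(regions=[regions], device=torch.device("cpu"), dtype=torch.float32)
# At the 32x32 level (downscale factor 2) the flattened query length is 32*32 = 1024. The key length is the
# total number of text embeddings (154 here). Excluded query/key pairs receive a score of -10000.0.
attn_mask = rpd.get_cross_attn_mask(query_seq_len=1024, key_seq_len=154)
print(attn_mask.shape)  # torch.Size([1, 1024, 154])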

View File: invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py

@ -1,26 +1,20 @@
from __future__ import annotations
import math
from contextlib import contextmanager
from typing import Any, Callable, Optional, Union
import torch
from diffusers import UNet2DConditionModel
from typing_extensions import TypeAlias
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
ConditioningData,
ExtraConditioningInfo,
SDXLConditioningInfo,
)
from .cross_attention_control import (
CrossAttentionType,
CrossAttnControlContext,
SwapCrossAttnContext,
setup_cross_attention_control_attention_processors,
IPAdapterData,
Range,
TextConditioningData,
TextConditioningRegions,
)
from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
ModelForwardCallback: TypeAlias = Union[
# x, t, conditioning, Optional[cross-attention kwargs]
@ -58,31 +52,8 @@ class InvokeAIDiffuserComponent:
self.conditioning = None
self.model = model
self.model_forward_callback = model_forward_callback
self.cross_attention_control_context = None
self.sequential_guidance = config.sequential_guidance
@contextmanager
def custom_attention_context(
self,
unet: UNet2DConditionModel,
extra_conditioning_info: Optional[ExtraConditioningInfo],
):
old_attn_processors = unet.attn_processors
try:
self.cross_attention_control_context = CrossAttnControlContext(
arguments=extra_conditioning_info.cross_attention_control_args,
)
setup_cross_attention_control_attention_processors(
unet,
self.cross_attention_control_context,
)
yield None
finally:
self.cross_attention_control_context = None
unet.set_attn_processor(old_attn_processors)
def do_controlnet_step(
self,
control_data,
@ -90,7 +61,7 @@ class InvokeAIDiffuserComponent:
timestep: torch.Tensor,
step_index: int,
total_step_count: int,
conditioning_data,
conditioning_data: TextConditioningData,
):
down_block_res_samples, mid_block_res_sample = None, None
@ -123,28 +94,28 @@ class InvokeAIDiffuserComponent:
added_cond_kwargs = None
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
if conditioning_data.is_sdxl():
added_cond_kwargs = {
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
"time_ids": conditioning_data.text_embeddings.add_time_ids,
"text_embeds": conditioning_data.cond_text.pooled_embeds,
"time_ids": conditioning_data.cond_text.add_time_ids,
}
encoder_hidden_states = conditioning_data.text_embeddings.embeds
encoder_hidden_states = conditioning_data.cond_text.embeds
encoder_attention_mask = None
else:
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
if conditioning_data.is_sdxl():
added_cond_kwargs = {
"text_embeds": torch.cat(
[
# TODO: how to pad? just by zeros? or even truncate?
conditioning_data.unconditioned_embeddings.pooled_embeds,
conditioning_data.text_embeddings.pooled_embeds,
conditioning_data.uncond_text.pooled_embeds,
conditioning_data.cond_text.pooled_embeds,
],
dim=0,
),
"time_ids": torch.cat(
[
conditioning_data.unconditioned_embeddings.add_time_ids,
conditioning_data.text_embeddings.add_time_ids,
conditioning_data.uncond_text.add_time_ids,
conditioning_data.cond_text.add_time_ids,
],
dim=0,
),
@ -153,8 +124,8 @@ class InvokeAIDiffuserComponent:
encoder_hidden_states,
encoder_attention_mask,
) = self._concat_conditionings_for_batch(
conditioning_data.unconditioned_embeddings.embeds,
conditioning_data.text_embeddings.embeds,
conditioning_data.uncond_text.embeds,
conditioning_data.cond_text.embeds,
)
if isinstance(control_datum.weight, list):
# if controlnet has multiple weights, use the weight for the current step
@ -198,24 +169,15 @@ class InvokeAIDiffuserComponent:
self,
sample: torch.Tensor,
timestep: torch.Tensor,
conditioning_data: ConditioningData,
conditioning_data: TextConditioningData,
ip_adapter_data: Optional[list[IPAdapterData]],
step_index: int,
total_step_count: int,
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
):
cross_attention_control_types_to_do = []
if self.cross_attention_control_context is not None:
percent_through = step_index / total_step_count
cross_attention_control_types_to_do = (
self.cross_attention_control_context.get_active_cross_attention_control_types_for_step(percent_through)
)
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
if wants_cross_attention_control or self.sequential_guidance:
# If wants_cross_attention_control is True, we force the sequential mode to be used, because cross-attention
# control is currently only supported in sequential mode.
if self.sequential_guidance:
(
unconditioned_next_x,
conditioned_next_x,
@ -223,7 +185,9 @@ class InvokeAIDiffuserComponent:
x=sample,
sigma=timestep,
conditioning_data=conditioning_data,
cross_attention_control_types_to_do=cross_attention_control_types_to_do,
ip_adapter_data=ip_adapter_data,
step_index=step_index,
total_step_count=total_step_count,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
@ -236,6 +200,9 @@ class InvokeAIDiffuserComponent:
x=sample,
sigma=timestep,
conditioning_data=conditioning_data,
ip_adapter_data=ip_adapter_data,
step_index=step_index,
total_step_count=total_step_count,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
@ -294,53 +261,84 @@ class InvokeAIDiffuserComponent:
def _apply_standard_conditioning(
self,
x,
sigma,
conditioning_data: ConditioningData,
x: torch.Tensor,
sigma: torch.Tensor,
conditioning_data: TextConditioningData,
ip_adapter_data: Optional[list[IPAdapterData]],
step_index: int,
total_step_count: int,
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
):
) -> tuple[torch.Tensor, torch.Tensor]:
"""Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
the cost of higher memory usage.
"""
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
cross_attention_kwargs = {}
if ip_adapter_data is not None:
ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]
# Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": [
torch.stack(
[ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds]
)
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
]
}
image_prompt_embeds = [
torch.stack([ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds])
for ipa_conditioning in ip_adapter_conditioning
]
scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
ip_masks = [ipa.mask for ipa in ip_adapter_data]
regional_ip_data = RegionalIPData(
image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device
)
cross_attention_kwargs["regional_ip_data"] = regional_ip_data
added_cond_kwargs = None
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
if conditioning_data.is_sdxl():
added_cond_kwargs = {
"text_embeds": torch.cat(
[
# TODO: how to pad? just by zeros? or even truncate?
conditioning_data.unconditioned_embeddings.pooled_embeds,
conditioning_data.text_embeddings.pooled_embeds,
conditioning_data.uncond_text.pooled_embeds,
conditioning_data.cond_text.pooled_embeds,
],
dim=0,
),
"time_ids": torch.cat(
[
conditioning_data.unconditioned_embeddings.add_time_ids,
conditioning_data.text_embeddings.add_time_ids,
conditioning_data.uncond_text.add_time_ids,
conditioning_data.cond_text.add_time_ids,
],
dim=0,
),
}
if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
# TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
# and masks are not changing from step-to-step, so this really only needs to be done once. While this seems
# painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
# the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
# awkward to handle both standard conditioning and sequential conditioning further up the stack.
regions = []
for c, r in [
(conditioning_data.uncond_text, conditioning_data.uncond_regions),
(conditioning_data.cond_text, conditioning_data.cond_regions),
]:
if r is None:
# Create a dummy mask and range for text conditioning that doesn't have region masks.
_, _, h, w = x.shape
r = TextConditioningRegions(
masks=torch.ones((1, 1, h, w), dtype=x.dtype),
ranges=[Range(start=0, end=c.embeds.shape[1])],
)
regions.append(r)
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
regions=regions, device=x.device, dtype=x.dtype
)
cross_attention_kwargs["percent_through"] = step_index / total_step_count
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
conditioning_data.unconditioned_embeddings.embeds, conditioning_data.text_embeddings.embeds
conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds
)
both_results = self.model_forward_callback(
x_twice,
@ -360,8 +358,10 @@ class InvokeAIDiffuserComponent:
self,
x: torch.Tensor,
sigma,
conditioning_data: ConditioningData,
cross_attention_control_types_to_do: list[CrossAttentionType],
conditioning_data: TextConditioningData,
ip_adapter_data: Optional[list[IPAdapterData]],
step_index: int,
total_step_count: int,
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
@ -391,53 +391,48 @@ class InvokeAIDiffuserComponent:
if mid_block_additional_residual is not None:
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
# If cross-attention control is enabled, prepare the SwapCrossAttnContext.
cross_attn_processor_context = None
if self.cross_attention_control_context is not None:
# Note that the SwapCrossAttnContext is initialized with an empty list of cross_attention_types_to_do.
# This list is empty because cross-attention control is not applied in the unconditioned pass. This field
# will be populated before the conditioned pass.
cross_attn_processor_context = SwapCrossAttnContext(
modified_text_embeddings=self.cross_attention_control_context.arguments.edited_conditioning,
index_map=self.cross_attention_control_context.cross_attention_index_map,
mask=self.cross_attention_control_context.cross_attention_mask,
cross_attention_types_to_do=[],
)
#####################
# Unconditioned pass
#####################
cross_attention_kwargs = None
cross_attention_kwargs = {}
# Prepare IP-Adapter cross-attention kwargs for the unconditioned pass.
if conditioning_data.ip_adapter_conditioning is not None:
if ip_adapter_data is not None:
ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": [
torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
]
}
image_prompt_embeds = [
torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
for ipa_conditioning in ip_adapter_conditioning
]
# Prepare cross-attention control kwargs for the unconditioned pass.
if cross_attn_processor_context is not None:
cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
ip_masks = [ipa.mask for ipa in ip_adapter_data]
regional_ip_data = RegionalIPData(
image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device
)
cross_attention_kwargs["regional_ip_data"] = regional_ip_data
# Prepare SDXL conditioning kwargs for the unconditioned pass.
added_cond_kwargs = None
is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
if is_sdxl:
if conditioning_data.is_sdxl():
added_cond_kwargs = {
"text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
"time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
"text_embeds": conditioning_data.uncond_text.pooled_embeds,
"time_ids": conditioning_data.uncond_text.add_time_ids,
}
# Prepare prompt regions for the unconditioned pass.
if conditioning_data.uncond_regions is not None:
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
regions=[conditioning_data.uncond_regions], device=x.device, dtype=x.dtype
)
cross_attention_kwargs["percent_through"] = step_index / total_step_count
# Run unconditioned UNet denoising (i.e. negative prompt).
unconditioned_next_x = self.model_forward_callback(
x,
sigma,
conditioning_data.unconditioned_embeddings.embeds,
conditioning_data.uncond_text.embeds,
cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=uncond_down_block,
mid_block_additional_residual=uncond_mid_block,
@ -449,36 +444,43 @@ class InvokeAIDiffuserComponent:
# Conditioned pass
###################
cross_attention_kwargs = None
cross_attention_kwargs = {}
# Prepare IP-Adapter cross-attention kwargs for the conditioned pass.
if conditioning_data.ip_adapter_conditioning is not None:
if ip_adapter_data is not None:
ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": [
torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
]
}
image_prompt_embeds = [
torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
for ipa_conditioning in ip_adapter_conditioning
]
# Prepare cross-attention control kwargs for the conditioned pass.
if cross_attn_processor_context is not None:
cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
ip_masks = [ipa.mask for ipa in ip_adapter_data]
regional_ip_data = RegionalIPData(
image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device
)
cross_attention_kwargs["regional_ip_data"] = regional_ip_data
# Prepare SDXL conditioning kwargs for the conditioned pass.
added_cond_kwargs = None
if is_sdxl:
if conditioning_data.is_sdxl():
added_cond_kwargs = {
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
"time_ids": conditioning_data.text_embeddings.add_time_ids,
"text_embeds": conditioning_data.cond_text.pooled_embeds,
"time_ids": conditioning_data.cond_text.add_time_ids,
}
# Prepare prompt regions for the conditioned pass.
if conditioning_data.cond_regions is not None:
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
regions=[conditioning_data.cond_regions], device=x.device, dtype=x.dtype
)
cross_attention_kwargs["percent_through"] = step_index / total_step_count
# Run conditioned UNet denoising (i.e. positive prompt).
conditioned_next_x = self.model_forward_callback(
x,
sigma,
conditioning_data.text_embeddings.embeds,
conditioning_data.cond_text.embeds,
cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=cond_down_block,
mid_block_additional_residual=cond_mid_block,
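
Editor's note: the hunks above only cover preparing and running the unconditioned and conditioned UNet passes. For context, classifier-free guidance then combines their outputs roughly as below. This is a sketch of the standard formulation referenced by the `guidance_scale` / `guidance_rescale_multiplier` comments, not the exact InvokeAI combination code.

import torch

def combine_cfg(unconditioned_next_x: torch.Tensor, conditioned_next_x: torch.Tensor,
                guidance_scale: float, guidance_rescale_multiplier: float = 0.0) -> torch.Tensor:
    # Classifier-free guidance: `w` of eq. 2 in the Imagen paper.
    combined = unconditioned_next_x + guidance_scale * (conditioned_next_x - unconditioned_next_x)
    if guidance_rescale_multiplier > 0:
        # Rescaling from "Common Diffusion Noise Schedules and Sample Steps are Flawed" (0.7 suggested for
        # ztsnr models): pull the combined prediction's std back toward the conditioned prediction's std.
        dims = list(range(1, combined.ndim))
        std_cond = conditioned_next_x.std(dim=dims, keepdim=True)
        std_combined = combined.std(dim=dims, keepdim=True)
        rescaled = combined * (std_cond / std_combined)
        combined = guidance_rescale_multiplier * rescaled + (1 - guidance_rescale_multiplier) * combined
    return combined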

View File: invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py (new)

@ -0,0 +1,68 @@
from contextlib import contextmanager
from typing import List, Optional, TypedDict
from diffusers.models import UNet2DConditionModel
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import (
CustomAttnProcessor2_0,
IPAdapterAttentionWeights,
)
class UNetIPAdapterData(TypedDict):
ip_adapter: IPAdapter
target_blocks: List[str]
class UNetAttentionPatcher:
"""A class for patching a UNet with CustomAttnProcessor2_0 attention layers."""
def __init__(self, ip_adapter_data: Optional[List[UNetIPAdapterData]]):
self._ip_adapters = ip_adapter_data
def _prepare_attention_processors(self, unet: UNet2DConditionModel):
"""Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
weights into them (if IP-Adapters are being applied).
Note that the `unet` param is only used to determine attention block dimensions and naming.
"""
# Construct a dict of attention processors based on the UNet's architecture.
attn_procs = {}
for idx, name in enumerate(unet.attn_processors.keys()):
if name.endswith("attn1.processor") or self._ip_adapters is None:
# "attn1" processors do not use IP-Adapters.
attn_procs[name] = CustomAttnProcessor2_0()
else:
# Collect the weights from each IP Adapter for the idx'th attention processor.
ip_adapter_attention_weights_collection: list[IPAdapterAttentionWeights] = []
for ip_adapter in self._ip_adapters:
ip_adapter_weights = ip_adapter["ip_adapter"].attn_weights.get_attention_processor_weights(idx)
skip = True
for block in ip_adapter["target_blocks"]:
if block in name:
skip = False
break
ip_adapter_attention_weights: IPAdapterAttentionWeights = IPAdapterAttentionWeights(
ip_adapter_weights=ip_adapter_weights, skip=skip
)
ip_adapter_attention_weights_collection.append(ip_adapter_attention_weights)
attn_procs[name] = CustomAttnProcessor2_0(ip_adapter_attention_weights_collection)
return attn_procs
@contextmanager
def apply_ip_adapter_attention(self, unet: UNet2DConditionModel):
"""A context manager that patches `unet` with CustomAttnProcessor2_0 attention layers."""
attn_procs = self._prepare_attention_processors(unet)
orig_attn_processors = unet.attn_processors
try:
# Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from
# the passed dict. So, if you wanted to keep the dict for future use, you'd have to make a
# moderately-shallow copy of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
unet.set_attn_processor(attn_procs)
yield None
finally:
unet.set_attn_processor(orig_attn_processors)
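
Editor's note: a brief usage sketch of the patcher with no IP-Adapters, i.e. the same pattern the pipeline uses in this diff (construct `UNetAttentionPatcher`, then wrap denoising in `apply_ip_adapter_attention`). The tiny randomly-initialised UNet configuration is an arbitrary assumption chosen only to demonstrate the patch/restore round-trip.

from diffusers import UNet2DConditionModel

from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher

# A tiny UNet, used here only to demonstrate patching and restoring the attention processors.
unet = UNet2DConditionModel(
    block_out_channels=(32, 64), layers_per_block=1, cross_attention_dim=32, attention_head_dim=8,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
)

patcher = UNetAttentionPatcher(ip_adapter_data=None)  # no IP-Adapters: every layer gets a bare custom processor
original_processors = dict(unet.attn_processors)
with patcher.apply_ip_adapter_attention(unet):
    # Inside the context, all attention layers run CustomAttnProcessor2_0.
    assert all(isinstance(p, CustomAttnProcessor2_0) for p in unet.attn_processors.values())
# On exit, the original attention processors are restored.
assert {type(p) for p in unet.attn_processors.values()} == {type(p) for p in original_processors.values()}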