Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into pr/6086
@@ -21,12 +21,11 @@ from pydantic import Field
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from invokeai.app.services.config.config_default import get_config
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
from invokeai.backend.util.attention import auto_detect_slice_size
from invokeai.backend.util.devices import normalize_device
from invokeai.backend.util.devices import TorchDevice
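The device-helper import change above pairs with a call-site change later in this file: free VRAM is now queried with TorchDevice.normalize() instead of the old normalize_device() helper. A minimal sketch of the new call, assuming a CUDA device is available (the device name is illustrative):

import torch

from invokeai.backend.util.devices import TorchDevice

# torch.cuda.mem_get_info() returns (free_bytes, total_bytes) for the given device.
device = torch.device("cuda:0")
mem_free, mem_total = torch.cuda.mem_get_info(TorchDevice.normalize(device))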
@dataclass
|
||||
@ -149,16 +148,6 @@ class ControlNetData:
|
||||
resize_mode: str = Field(default="just_resize")
|
||||
|
||||
|
||||
@dataclass
|
||||
class IPAdapterData:
|
||||
ip_adapter_model: IPAdapter = Field(default=None)
|
||||
# TODO: change to polymorphic so can do different weights per step (once implemented...)
|
||||
weight: Union[float, List[float]] = Field(default=1.0)
|
||||
# weight: float = Field(default=1.0)
|
||||
begin_step_percent: float = Field(default=0.0)
|
||||
end_step_percent: float = Field(default=1.0)
|
||||
|
||||
|
||||
@dataclass
|
||||
class T2IAdapterData:
|
||||
"""A structure containing the information required to apply conditioning from a single T2I-Adapter model."""
|
||||
@ -266,7 +255,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
if self.unet.device.type == "cpu" or self.unet.device.type == "mps":
|
||||
mem_free = psutil.virtual_memory().free
|
||||
elif self.unet.device.type == "cuda":
|
||||
mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.unet.device))
|
||||
mem_free, _ = torch.cuda.mem_get_info(TorchDevice.normalize(self.unet.device))
|
||||
else:
|
||||
raise ValueError(f"unrecognized device {self.unet.device}")
|
||||
# input tensor of [1, 4, h/8, w/8]
|
||||
@ -295,7 +284,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
self,
|
||||
latents: torch.Tensor,
|
||||
num_inference_steps: int,
|
||||
conditioning_data: ConditioningData,
|
||||
scheduler_step_kwargs: dict[str, Any],
|
||||
conditioning_data: TextConditioningData,
|
||||
*,
|
||||
noise: Optional[torch.Tensor],
|
||||
timesteps: torch.Tensor,
|
||||
@ -308,7 +298,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
masked_latents: Optional[torch.Tensor] = None,
|
||||
gradient_mask: Optional[bool] = False,
|
||||
seed: Optional[int] = None,
|
||||
seed: int,
|
||||
) -> torch.Tensor:
|
||||
if init_timestep.shape[0] == 0:
|
||||
return latents
|
||||
@ -326,20 +316,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
||||
|
||||
if mask is not None:
|
||||
# if no noise provided, noisify unmasked area based on seed(or 0 as fallback)
|
||||
if noise is None:
|
||||
noise = torch.randn(
|
||||
orig_latents.shape,
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
generator=torch.Generator(device="cpu").manual_seed(seed or 0),
|
||||
).to(device=orig_latents.device, dtype=orig_latents.dtype)
|
||||
|
||||
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
||||
latents = torch.lerp(
|
||||
orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
|
||||
)
|
||||
|
||||
if is_inpainting_model(self.unet):
|
||||
if masked_latents is None:
|
||||
raise Exception("Source image required for inpaint mask when inpaint model used!")
|
||||
@ -348,6 +324,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
self._unet_forward, mask, masked_latents
|
||||
)
|
||||
else:
|
||||
# if no noise provided, noisify unmasked area based on seed
|
||||
if noise is None:
|
||||
noise = torch.randn(
|
||||
orig_latents.shape,
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
generator=torch.Generator(device="cpu").manual_seed(seed),
|
||||
).to(device=orig_latents.device, dtype=orig_latents.dtype)
|
||||
|
||||
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask))
|
||||
|
||||
try:
|
||||
@ -355,6 +340,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
latents,
|
||||
timesteps,
|
||||
conditioning_data,
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
additional_guidance=additional_guidance,
|
||||
control_data=control_data,
|
||||
ip_adapter_data=ip_adapter_data,
|
||||
@ -380,7 +366,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
self,
|
||||
latents: torch.Tensor,
|
||||
timesteps,
|
||||
conditioning_data: ConditioningData,
|
||||
conditioning_data: TextConditioningData,
|
||||
scheduler_step_kwargs: dict[str, Any],
|
||||
*,
|
||||
additional_guidance: List[Callable] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
@ -397,22 +384,22 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
if timesteps.shape[0] == 0:
|
||||
return latents
|
||||
|
||||
ip_adapter_unet_patcher = None
|
||||
extra_conditioning_info = conditioning_data.text_embeddings.extra_conditioning
|
||||
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
|
||||
attn_ctx = self.invokeai_diffuser.custom_attention_context(
|
||||
self.invokeai_diffuser.model,
|
||||
extra_conditioning_info=extra_conditioning_info,
|
||||
use_ip_adapter = ip_adapter_data is not None
|
||||
use_regional_prompting = (
|
||||
conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
|
||||
)
|
||||
unet_attention_patcher = None
|
||||
self.use_ip_adapter = use_ip_adapter
|
||||
attn_ctx = nullcontext()
|
||||
|
||||
if use_ip_adapter or use_regional_prompting:
|
||||
ip_adapters: Optional[List[UNetIPAdapterData]] = (
|
||||
[{"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data]
|
||||
if use_ip_adapter
|
||||
else None
|
||||
)
|
||||
self.use_ip_adapter = False
|
||||
elif ip_adapter_data is not None:
|
||||
# TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
|
||||
# As it is now, the IP-Adapter will silently be skipped.
|
||||
ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
|
||||
attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
|
||||
self.use_ip_adapter = True
|
||||
else:
|
||||
attn_ctx = nullcontext()
|
||||
unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
|
||||
attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
|
||||
|
||||
with attn_ctx:
|
||||
if callback is not None:
|
||||
@ -435,11 +422,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
conditioning_data,
|
||||
step_index=i,
|
||||
total_step_count=len(timesteps),
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
additional_guidance=additional_guidance,
|
||||
control_data=control_data,
|
||||
ip_adapter_data=ip_adapter_data,
|
||||
t2i_adapter_data=t2i_adapter_data,
|
||||
ip_adapter_unet_patcher=ip_adapter_unet_patcher,
|
||||
)
|
||||
latents = step_output.prev_sample
|
||||
predicted_original = getattr(step_output, "pred_original_sample", None)
|
||||
@ -463,14 +450,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
self,
|
||||
t: torch.Tensor,
|
||||
latents: torch.Tensor,
|
||||
conditioning_data: ConditioningData,
|
||||
conditioning_data: TextConditioningData,
|
||||
step_index: int,
|
||||
total_step_count: int,
|
||||
scheduler_step_kwargs: dict[str, Any],
|
||||
additional_guidance: List[Callable] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
||||
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
||||
ip_adapter_unet_patcher: Optional[UNetPatcher] = None,
|
||||
):
|
||||
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
||||
timestep = t[0]
|
||||
@ -485,23 +472,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
# i.e. before or after passing it to InvokeAIDiffuserComponent
|
||||
latent_model_input = self.scheduler.scale_model_input(latents, timestep)
|
||||
|
||||
# handle IP-Adapter
|
||||
if self.use_ip_adapter and ip_adapter_data is not None: # somewhat redundant but logic is clearer
|
||||
for i, single_ip_adapter_data in enumerate(ip_adapter_data):
|
||||
first_adapter_step = math.floor(single_ip_adapter_data.begin_step_percent * total_step_count)
|
||||
last_adapter_step = math.ceil(single_ip_adapter_data.end_step_percent * total_step_count)
|
||||
weight = (
|
||||
single_ip_adapter_data.weight[step_index]
|
||||
if isinstance(single_ip_adapter_data.weight, List)
|
||||
else single_ip_adapter_data.weight
|
||||
)
|
||||
if step_index >= first_adapter_step and step_index <= last_adapter_step:
|
||||
# Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
|
||||
ip_adapter_unet_patcher.set_scale(i, weight)
|
||||
else:
|
||||
# Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
|
||||
ip_adapter_unet_patcher.set_scale(i, 0.0)
|
||||
|
||||
# Handle ControlNet(s)
|
||||
down_block_additional_residuals = None
|
||||
mid_block_additional_residual = None
|
||||
@ -550,6 +520,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
step_index=step_index,
|
||||
total_step_count=total_step_count,
|
||||
conditioning_data=conditioning_data,
|
||||
ip_adapter_data=ip_adapter_data,
|
||||
down_block_additional_residuals=down_block_additional_residuals, # for ControlNet
|
||||
mid_block_additional_residual=mid_block_additional_residual, # for ControlNet
|
||||
down_intrablock_additional_residuals=down_intrablock_additional_residuals, # for T2I-Adapter
|
||||
@ -569,7 +540,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
step_output = self.scheduler.step(noise_pred, timestep, latents, **conditioning_data.scheduler_args)
|
||||
step_output = self.scheduler.step(noise_pred, timestep, latents, **scheduler_step_kwargs)
|
||||
|
||||
# TODO: discuss injection point options. For now this is a patch to get progress images working with inpainting again.
|
||||
for guidance in additional_guidance:
|
||||
|
@@ -1,27 +1,17 @@
import dataclasses
import inspect
from dataclasses import dataclass, field
from typing import Any, List, Optional, Union
import math
from dataclasses import dataclass
from typing import List, Optional, Union

import torch

from .cross_attention_control import Arguments


@dataclass
class ExtraConditioningInfo:
    tokens_count_including_eos_bos: int
    cross_attention_control_args: Optional[Arguments] = None

    @property
    def wants_cross_attention_control(self):
        return self.cross_attention_control_args is not None
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter


@dataclass
class BasicConditioningInfo:
    """SD 1/2 text conditioning information produced by Compel."""

    embeds: torch.Tensor
    extra_conditioning: Optional[ExtraConditioningInfo]

    def to(self, device, dtype=None):
        self.embeds = self.embeds.to(device=device, dtype=dtype)
@@ -35,6 +25,8 @@ class ConditioningFieldData:

@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
    """SDXL text conditioning information produced by Compel."""

    pooled_embeds: torch.Tensor
    add_time_ids: torch.Tensor

@@ -57,37 +49,75 @@ class IPAdapterConditioningInfo:


@dataclass
class ConditioningData:
    unconditioned_embeddings: BasicConditioningInfo
    text_embeddings: BasicConditioningInfo
    """
    Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
    `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
    Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
    images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
    """
    guidance_scale: Union[float, List[float]]
    """ for models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7 .
    ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf)
    """
    guidance_rescale_multiplier: float = 0
    scheduler_args: dict[str, Any] = field(default_factory=dict)
class IPAdapterData:
    ip_adapter_model: IPAdapter
    ip_adapter_conditioning: IPAdapterConditioningInfo
    mask: torch.Tensor
    target_blocks: List[str]

    ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]] = None
    # Either a single weight applied to all steps, or a list of weights for each step.
    weight: Union[float, List[float]] = 1.0
    begin_step_percent: float = 0.0
    end_step_percent: float = 1.0

    @property
    def dtype(self):
        return self.text_embeddings.dtype
    def scale_for_step(self, step_index: int, total_steps: int) -> float:
        first_adapter_step = math.floor(self.begin_step_percent * total_steps)
        last_adapter_step = math.ceil(self.end_step_percent * total_steps)
        weight = self.weight[step_index] if isinstance(self.weight, List) else self.weight
        if step_index >= first_adapter_step and step_index <= last_adapter_step:
            # Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
            return weight
        # Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
        return 0.0
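The scale_for_step() method above is what the pipeline now calls once per denoising step to decide how strongly each IP-Adapter is applied; this gating previously lived in the pipeline's step loop. A standalone restatement of the rule, with illustrative numbers rather than values from this diff:

import math
from typing import List, Union

def scale_for_step(weight: Union[float, List[float]], begin_step_percent: float,
                   end_step_percent: float, step_index: int, total_steps: int) -> float:
    # Same gating as IPAdapterData.scale_for_step(): outside the begin/end window the scale is 0.0.
    first_adapter_step = math.floor(begin_step_percent * total_steps)
    last_adapter_step = math.ceil(end_step_percent * total_steps)
    weight_for_step = weight[step_index] if isinstance(weight, list) else weight
    return weight_for_step if first_adapter_step <= step_index <= last_adapter_step else 0.0

# An adapter weighted 0.8 that is only active for the middle half of a 20-step run:
scales = [scale_for_step(0.8, 0.25, 0.75, i, 20) for i in range(20)]
# -> 0.0 for steps 0-4, 0.8 for steps 5-15, 0.0 for steps 16-19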
    def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
        scheduler_args = dict(self.scheduler_args)
        step_method = inspect.signature(scheduler.step)
        for name, value in kwargs.items():
            try:
                step_method.bind_partial(**{name: value})
            except TypeError:
                # FIXME: don't silently discard arguments
                pass  # debug("%s does not accept argument named %r", scheduler, name)
            else:
                scheduler_args[name] = value
        return dataclasses.replace(self, scheduler_args=scheduler_args)

@dataclass
class Range:
    start: int
    end: int


class TextConditioningRegions:
    def __init__(
        self,
        masks: torch.Tensor,
        ranges: list[Range],
    ):
        # A binary mask indicating the regions of the image that the prompt should be applied to.
        # Shape: (1, num_prompts, height, width)
        # Dtype: torch.bool
        self.masks = masks

        # A list of ranges indicating the start and end indices of the embeddings that corresponding mask applies to.
        # ranges[i] contains the embedding range for the i'th prompt / mask.
        self.ranges = ranges

        assert self.masks.shape[1] == len(self.ranges)


class TextConditioningData:
    def __init__(
        self,
        uncond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
        cond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
        uncond_regions: Optional[TextConditioningRegions],
        cond_regions: Optional[TextConditioningRegions],
        guidance_scale: Union[float, List[float]],
        guidance_rescale_multiplier: float = 0,
    ):
        self.uncond_text = uncond_text
        self.cond_text = cond_text
        self.uncond_regions = uncond_regions
        self.cond_regions = cond_regions
        # Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
        # `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
        # Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
        # images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
        self.guidance_scale = guidance_scale
        # For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7.
        # See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
        self.guidance_rescale_multiplier = guidance_rescale_multiplier

    def is_sdxl(self):
        assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
        return isinstance(self.cond_text, SDXLConditioningInfo)
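A short sketch of how the new region structures fit together. The mask layout and token ranges below are made up for illustration; in the pipeline they come from the user's region masks and from how the regional prompts were packed into a single embedding tensor:

import torch

from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range, TextConditioningRegions

# Two prompt regions over a 64x64 latent grid: the first prompt's tokens (0..77) apply to the
# left half of the image, the second prompt's tokens (77..154) to the right half.
masks = torch.zeros((1, 2, 64, 64), dtype=torch.bool)
masks[0, 0, :, :32] = True
masks[0, 1, :, 32:] = True

regions = TextConditioningRegions(
    masks=masks,
    ranges=[Range(start=0, end=77), Range(start=77, end=154)],
)
# The constructor asserts masks.shape[1] == len(ranges): one embedding range per mask. An object
# like this is what gets passed as cond_regions (or uncond_regions) when building TextConditioningData.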
@ -1,218 +0,0 @@
|
||||
# adapted from bloc97's CrossAttentionControl colab
|
||||
# https://github.com/bloc97/CrossAttentionControl
|
||||
|
||||
|
||||
import enum
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from compel.cross_attention_control import Arguments
|
||||
from diffusers.models.attention_processor import Attention, SlicedAttnProcessor
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
|
||||
from invokeai.backend.util.devices import torch_dtype
|
||||
|
||||
|
||||
class CrossAttentionType(enum.Enum):
|
||||
SELF = 1
|
||||
TOKENS = 2
|
||||
|
||||
|
||||
class CrossAttnControlContext:
|
||||
def __init__(self, arguments: Arguments):
|
||||
"""
|
||||
:param arguments: Arguments for the cross-attention control process
|
||||
"""
|
||||
self.cross_attention_mask: Optional[torch.Tensor] = None
|
||||
self.cross_attention_index_map: Optional[torch.Tensor] = None
|
||||
self.arguments = arguments
|
||||
|
||||
def get_active_cross_attention_control_types_for_step(
|
||||
self, percent_through: float = None
|
||||
) -> list[CrossAttentionType]:
|
||||
"""
|
||||
Should cross-attention control be applied on the given step?
|
||||
:param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
|
||||
:return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
|
||||
"""
|
||||
if percent_through is None:
|
||||
return [CrossAttentionType.SELF, CrossAttentionType.TOKENS]
|
||||
|
||||
opts = self.arguments.edit_options
|
||||
to_control = []
|
||||
if opts["s_start"] <= percent_through < opts["s_end"]:
|
||||
to_control.append(CrossAttentionType.SELF)
|
||||
if opts["t_start"] <= percent_through < opts["t_end"]:
|
||||
to_control.append(CrossAttentionType.TOKENS)
|
||||
return to_control
|
||||
|
||||
|
||||
def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: CrossAttnControlContext):
|
||||
"""
|
||||
Inject attention parameters and functions into the passed in model to enable cross attention editing.
|
||||
|
||||
:param model: The unet model to inject into.
|
||||
:return: None
|
||||
"""
|
||||
|
||||
# adapted from init_attention_edit
|
||||
device = context.arguments.edited_conditioning.device
|
||||
|
||||
# urgh. should this be hardcoded?
|
||||
max_length = 77
|
||||
# mask=1 means use base prompt attention, mask=0 means use edited prompt attention
|
||||
mask = torch.zeros(max_length, dtype=torch_dtype(device))
|
||||
indices_target = torch.arange(max_length, dtype=torch.long)
|
||||
indices = torch.arange(max_length, dtype=torch.long)
|
||||
for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
|
||||
if b0 < max_length:
|
||||
if name == "equal": # or (name == "replace" and a1 - a0 == b1 - b0):
|
||||
# these tokens have not been edited
|
||||
indices[b0:b1] = indices_target[a0:a1]
|
||||
mask[b0:b1] = 1
|
||||
|
||||
context.cross_attention_mask = mask.to(device)
|
||||
context.cross_attention_index_map = indices.to(device)
|
||||
old_attn_processors = unet.attn_processors
|
||||
if torch.backends.mps.is_available():
|
||||
# see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
|
||||
unet.set_attn_processor(SwapCrossAttnProcessor())
|
||||
else:
|
||||
# try to re-use an existing slice size
|
||||
default_slice_size = 4
|
||||
slice_size = next(
|
||||
(p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size
|
||||
)
|
||||
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
|
||||
|
||||
|
||||
@dataclass
|
||||
class SwapCrossAttnContext:
|
||||
modified_text_embeddings: torch.Tensor
|
||||
index_map: torch.Tensor # maps from original prompt token indices to the equivalent tokens in the modified prompt
|
||||
mask: torch.Tensor # in the target space of the index_map
|
||||
cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=list)
|
||||
|
||||
def wants_cross_attention_control(self, attn_type: CrossAttentionType) -> bool:
|
||||
return attn_type in self.cross_attention_types_to_do
|
||||
|
||||
@classmethod
|
||||
def make_mask_and_index_map(
|
||||
cls, edit_opcodes: list[tuple[str, int, int, int, int]], max_length: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# mask=1 means use original prompt attention, mask=0 means use modified prompt attention
|
||||
mask = torch.zeros(max_length)
|
||||
indices_target = torch.arange(max_length, dtype=torch.long)
|
||||
indices = torch.arange(max_length, dtype=torch.long)
|
||||
for name, a0, a1, b0, b1 in edit_opcodes:
|
||||
if b0 < max_length:
|
||||
if name == "equal":
|
||||
# these tokens remain the same as in the original prompt
|
||||
indices[b0:b1] = indices_target[a0:a1]
|
||||
mask[b0:b1] = 1
|
||||
|
||||
return mask, indices
|
||||
|
||||
|
||||
class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
|
||||
# TODO: dynamically pick slice size based on memory conditions
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
# kwargs
|
||||
swap_cross_attn_context: SwapCrossAttnContext = None,
|
||||
**kwargs,
|
||||
):
|
||||
attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS
|
||||
|
||||
# if cross-attention control is not in play, just call through to the base implementation.
|
||||
if (
|
||||
attention_type is CrossAttentionType.SELF
|
||||
or swap_cross_attn_context is None
|
||||
or not swap_cross_attn_context.wants_cross_attention_control(attention_type)
|
||||
):
|
||||
# print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass")
|
||||
return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)
|
||||
# else:
|
||||
# print(f"SwapCrossAttnContext for {attention_type} active")
|
||||
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
attention_mask = attn.prepare_attention_mask(
|
||||
attention_mask=attention_mask,
|
||||
target_length=sequence_length,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
dim = query.shape[-1]
|
||||
query = attn.head_to_batch_dim(query)
|
||||
|
||||
original_text_embeddings = encoder_hidden_states
|
||||
modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings
|
||||
original_text_key = attn.to_k(original_text_embeddings)
|
||||
modified_text_key = attn.to_k(modified_text_embeddings)
|
||||
original_value = attn.to_v(original_text_embeddings)
|
||||
modified_value = attn.to_v(modified_text_embeddings)
|
||||
|
||||
original_text_key = attn.head_to_batch_dim(original_text_key)
|
||||
modified_text_key = attn.head_to_batch_dim(modified_text_key)
|
||||
original_value = attn.head_to_batch_dim(original_value)
|
||||
modified_value = attn.head_to_batch_dim(modified_value)
|
||||
|
||||
# compute slices and prepare output tensor
|
||||
batch_size_attention = query.shape[0]
|
||||
hidden_states = torch.zeros(
|
||||
(batch_size_attention, sequence_length, dim // attn.heads),
|
||||
device=query.device,
|
||||
dtype=query.dtype,
|
||||
)
|
||||
|
||||
# do slices
|
||||
for i in range(max(1, hidden_states.shape[0] // self.slice_size)):
|
||||
start_idx = i * self.slice_size
|
||||
end_idx = (i + 1) * self.slice_size
|
||||
|
||||
query_slice = query[start_idx:end_idx]
|
||||
original_key_slice = original_text_key[start_idx:end_idx]
|
||||
modified_key_slice = modified_text_key[start_idx:end_idx]
|
||||
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
||||
|
||||
original_attn_slice = attn.get_attention_scores(query_slice, original_key_slice, attn_mask_slice)
|
||||
modified_attn_slice = attn.get_attention_scores(query_slice, modified_key_slice, attn_mask_slice)
|
||||
|
||||
# because the prompt modifications may result in token sequences shifted forwards or backwards,
|
||||
# the original attention probabilities must be remapped to account for token index changes in the
|
||||
# modified prompt
|
||||
remapped_original_attn_slice = torch.index_select(
|
||||
original_attn_slice, -1, swap_cross_attn_context.index_map
|
||||
)
|
||||
|
||||
# only some tokens taken from the original attention probabilities. this is controlled by the mask.
|
||||
mask = swap_cross_attn_context.mask
|
||||
inverse_mask = 1 - mask
|
||||
attn_slice = remapped_original_attn_slice * mask + modified_attn_slice * inverse_mask
|
||||
|
||||
del remapped_original_attn_slice, modified_attn_slice
|
||||
|
||||
attn_slice = torch.bmm(attn_slice, modified_value[start_idx:end_idx])
|
||||
hidden_states[start_idx:end_idx] = attn_slice
|
||||
|
||||
# done
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SwapCrossAttnProcessor(SlicedSwapCrossAttnProcesser):
|
||||
def __init__(self):
|
||||
super(SwapCrossAttnProcessor, self).__init__(slice_size=int(1e9)) # massive slice size = don't slice
|
invokeai/backend/stable_diffusion/diffusion/custom_atttention.py (new file, 214 lines)
@ -0,0 +1,214 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, cast
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from diffusers.models.attention_processor import Attention, AttnProcessor2_0
|
||||
|
||||
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
|
||||
from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData
|
||||
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
|
||||
|
||||
|
||||
@dataclass
|
||||
class IPAdapterAttentionWeights:
|
||||
ip_adapter_weights: IPAttentionProcessorWeights
|
||||
skip: bool
|
||||
|
||||
|
||||
class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
"""A custom implementation of AttnProcessor2_0 that supports additional Invoke features.
|
||||
This implementation is based on
|
||||
https://github.com/huggingface/diffusers/blame/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1204
|
||||
Supported custom features:
|
||||
- IP-Adapter
|
||||
- Regional prompt attention
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ip_adapter_attention_weights: Optional[List[IPAdapterAttentionWeights]] = None,
|
||||
):
|
||||
"""Initialize a CustomAttnProcessor2_0.
|
||||
Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are
|
||||
layer-specific are passed to __init__().
|
||||
Args:
|
||||
ip_adapter_weights: The IP-Adapter attention weights. ip_adapter_weights[i] contains the attention weights
|
||||
for the i'th IP-Adapter.
|
||||
"""
|
||||
super().__init__()
|
||||
self._ip_adapter_attention_weights = ip_adapter_attention_weights
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
# For Regional Prompting:
|
||||
regional_prompt_data: Optional[RegionalPromptData] = None,
|
||||
percent_through: Optional[torch.Tensor] = None,
|
||||
# For IP-Adapter:
|
||||
regional_ip_data: Optional[RegionalIPData] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
"""Apply attention.
|
||||
Args:
|
||||
regional_prompt_data: The regional prompt data for the current batch. If not None, this will be used to
|
||||
apply regional prompt masking.
|
||||
regional_ip_data: The IP-Adapter data for the current batch.
|
||||
"""
|
||||
# If true, we are doing cross-attention, if false we are doing self-attention.
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
|
||||
# Start unmodified block from AttnProcessor2_0.
|
||||
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
|
||||
residual = hidden_states
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
# End unmodified block from AttnProcessor2_0.
|
||||
|
||||
_, query_seq_len, _ = hidden_states.shape
|
||||
# Handle regional prompt attention masks.
|
||||
if regional_prompt_data is not None and is_cross_attention:
|
||||
assert percent_through is not None
|
||||
prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
|
||||
query_seq_len=query_seq_len, key_seq_len=sequence_length
|
||||
)
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = prompt_region_attention_mask
|
||||
else:
|
||||
attention_mask = prompt_region_attention_mask + attention_mask
|
||||
|
||||
# Start unmodified block from AttnProcessor2_0.
|
||||
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
# End unmodified block from AttnProcessor2_0.
|
||||
|
||||
# Apply IP-Adapter conditioning.
|
||||
if is_cross_attention:
|
||||
if self._ip_adapter_attention_weights:
|
||||
assert regional_ip_data is not None
|
||||
ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)
|
||||
|
||||
assert (
|
||||
len(regional_ip_data.image_prompt_embeds)
|
||||
== len(self._ip_adapter_attention_weights)
|
||||
== len(regional_ip_data.scales)
|
||||
== ip_masks.shape[1]
|
||||
)
|
||||
|
||||
for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds):
|
||||
ipa_weights = self._ip_adapter_attention_weights[ipa_index].ip_adapter_weights
|
||||
ipa_scale = regional_ip_data.scales[ipa_index]
|
||||
ip_mask = ip_masks[0, ipa_index, ...]
|
||||
|
||||
# The batch dimensions should match.
|
||||
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
|
||||
# The token_len dimensions should match.
|
||||
assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1]
|
||||
|
||||
ip_hidden_states = ipa_embed
|
||||
|
||||
# Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
|
||||
|
||||
if not self._ip_adapter_attention_weights[ipa_index].skip:
|
||||
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
|
||||
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
|
||||
|
||||
# Expected ip_key and ip_value shape:
|
||||
# (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
|
||||
|
||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# Expected ip_key and ip_value shape:
|
||||
# (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
|
||||
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
ip_hidden_states = F.scaled_dot_product_attention(
|
||||
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
|
||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
|
||||
batch_size, -1, attn.heads * head_dim
|
||||
)
|
||||
|
||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||
|
||||
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
|
||||
hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
|
||||
else:
|
||||
# If IP-Adapter is not enabled, then regional_ip_data should not be passed in.
|
||||
assert regional_ip_data is None
|
||||
|
||||
# Start unmodified block from AttnProcessor2_0.
|
||||
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
# End of unmodified block from AttnProcessor2_0
|
||||
|
||||
# casting torch.Tensor to torch.FloatTensor to avoid type issues
|
||||
return cast(torch.FloatTensor, hidden_states)
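CustomAttnProcessor2_0 above is installed by the UNetAttentionPatcher (defined in unet_attention_patcher.py, which is not part of this diff). A rough sketch of the kind of wiring involved, assuming a diffusers UNet2DConditionModel and ignoring the per-layer IP-Adapter weights that the real patcher supplies:

from diffusers import UNet2DConditionModel

from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0

# Illustrative only: replace every attention processor with the custom one. With no
# ip_adapter_attention_weights passed, IP-Adapter conditioning stays disabled and the processor
# behaves like the stock AttnProcessor2_0, plus optional regional prompt masking.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
unet.set_attn_processor({name: CustomAttnProcessor2_0() for name in unet.attn_processors.keys()})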
|
@@ -0,0 +1,72 @@
import torch


class RegionalIPData:
    """A class to manage the data for regional IP-Adapter conditioning."""

    def __init__(
        self,
        image_prompt_embeds: list[torch.Tensor],
        scales: list[float],
        masks: list[torch.Tensor],
        dtype: torch.dtype,
        device: torch.device,
        max_downscale_factor: int = 8,
    ):
        """Initialize a `IPAdapterConditioningData` object."""
        assert len(image_prompt_embeds) == len(scales) == len(masks)

        # The image prompt embeddings.
        # regional_ip_data[i] contains the image prompt embeddings for the i'th IP-Adapter. Each tensor
        # has shape (batch_size, num_ip_images, seq_len, ip_embedding_len).
        self.image_prompt_embeds = image_prompt_embeds

        # The scales for the IP-Adapter attention.
        # scales[i] contains the attention scale for the i'th IP-Adapter.
        self.scales = scales

        # The IP-Adapter masks.
        # self._masks_by_seq_len[s] contains the spatial masks for the downsampling level with query sequence length of
        # s. It has shape (batch_size, num_ip_images, query_seq_len, 1). The masks have values of 1.0 for included
        # regions and 0.0 for excluded regions.
        self._masks_by_seq_len = self._prepare_masks(masks, max_downscale_factor, device, dtype)

    def _prepare_masks(
        self, masks: list[torch.Tensor], max_downscale_factor: int, device: torch.device, dtype: torch.dtype
    ) -> dict[int, torch.Tensor]:
        """Prepare the masks for the IP-Adapter attention."""
        # Concatenate the masks so that they can be processed more efficiently.
        mask_tensor = torch.cat(masks, dim=1)

        mask_tensor = mask_tensor.to(device=device, dtype=dtype)

        masks_by_seq_len: dict[int, torch.Tensor] = {}

        # Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
        downscale_factor = 1
        while downscale_factor <= max_downscale_factor:
            b, num_ip_adapters, h, w = mask_tensor.shape
            # Assert that the batch size is 1, because I haven't thought through batch handling for this feature yet.
            assert b == 1

            # The IP-Adapters are applied in the cross-attention layers, where the query sequence length is the h * w of
            # the spatial features.
            query_seq_len = h * w

            masks_by_seq_len[query_seq_len] = mask_tensor.view((b, num_ip_adapters, -1, 1))

            downscale_factor *= 2
            if downscale_factor <= max_downscale_factor:
                # We use max pooling because we downscale to a pretty low resolution, so we don't want small mask
                # regions to be lost entirely.
                #
                # ceil_mode=True is set to mirror the downsampling behavior of SD and SDXL.
                #
                # TODO(ryand): In the future, we may want to experiment with other downsampling methods.
                mask_tensor = torch.nn.functional.max_pool2d(mask_tensor, kernel_size=2, stride=2, ceil_mode=True)

        return masks_by_seq_len

    def get_masks(self, query_seq_len: int) -> torch.Tensor:
        """Get the mask for the given query sequence length."""
        return self._masks_by_seq_len[query_seq_len]
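A hedged usage sketch for RegionalIPData. All values below are illustrative: one IP-Adapter with a scale of 0.7, image-prompt embeddings shaped (batch_size, num_ip_images, seq_len, ip_embedding_len), and a latent-resolution mask restricting the adapter to the top half of a 64x64 grid:

import torch

from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData

image_prompt_embeds = torch.randn(1, 2, 4, 768)   # (batch_size, num_ip_images, seq_len, ip_embedding_len)
mask = torch.zeros(1, 1, 64, 64)
mask[:, :, :32, :] = 1.0                          # 1.0 = included region, 0.0 = excluded

regional_ip = RegionalIPData(
    image_prompt_embeds=[image_prompt_embeds],
    scales=[0.7],
    masks=[mask],
    dtype=torch.float32,
    device=torch.device("cpu"),
)

# Masks are pre-pooled for each attention resolution and looked up by flattened query length.
print(regional_ip.get_masks(query_seq_len=64 * 64).shape)  # torch.Size([1, 1, 4096, 1])
print(regional_ip.get_masks(query_seq_len=32 * 32).shape)  # torch.Size([1, 1, 1024, 1])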
@@ -0,0 +1,105 @@
import torch
import torch.nn.functional as F

from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    TextConditioningRegions,
)


class RegionalPromptData:
    """A class to manage the prompt data for regional conditioning."""

    def __init__(
        self,
        regions: list[TextConditioningRegions],
        device: torch.device,
        dtype: torch.dtype,
        max_downscale_factor: int = 8,
    ):
        """Initialize a `RegionalPromptData` object.
        Args:
            regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the
                batch.
            device (torch.device): The device to use for the attention masks.
            dtype (torch.dtype): The data type to use for the attention masks.
            max_downscale_factor: Spatial masks will be prepared for downscale factors from 1 to max_downscale_factor
                in steps of 2x.
        """
        self._regions = regions
        self._device = device
        self._dtype = dtype
        # self._spatial_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query
        # sequence length of s.
        self._spatial_masks_by_seq_len: list[dict[int, torch.Tensor]] = self._prepare_spatial_masks(
            regions, max_downscale_factor
        )
        self._negative_cross_attn_mask_score = -10000.0

    def _prepare_spatial_masks(
        self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8
    ) -> list[dict[int, torch.Tensor]]:
        """Prepare the spatial masks for all downscaling factors."""
        # batch_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query sequence length
        # of s.
        batch_sample_masks_by_seq_len: list[dict[int, torch.Tensor]] = []

        for batch_sample_regions in regions:
            batch_sample_masks_by_seq_len.append({})

            batch_sample_masks = batch_sample_regions.masks.to(device=self._device, dtype=self._dtype)

            # Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
            downscale_factor = 1
            while downscale_factor <= max_downscale_factor:
                b, _num_prompts, h, w = batch_sample_masks.shape
                assert b == 1
                query_seq_len = h * w

                batch_sample_masks_by_seq_len[-1][query_seq_len] = batch_sample_masks

                downscale_factor *= 2
                if downscale_factor <= max_downscale_factor:
                    # We use max pooling because we downscale to a pretty low resolution, so we don't want small prompt
                    # regions to be lost entirely.
                    #
                    # ceil_mode=True is set to mirror the downsampling behavior of SD and SDXL.
                    #
                    # TODO(ryand): In the future, we may want to experiment with other downsampling methods (e.g.
                    # nearest interpolation), and could potentially use a weighted mask rather than a binary mask.
                    batch_sample_masks = F.max_pool2d(batch_sample_masks, kernel_size=2, stride=2, ceil_mode=True)

        return batch_sample_masks_by_seq_len

    def get_cross_attn_mask(self, query_seq_len: int, key_seq_len: int) -> torch.Tensor:
        """Get the cross-attention mask for the given query sequence length.
        Args:
            query_seq_len: The length of the flattened spatial features at the current downscaling level.
            key_seq_len (int): The sequence length of the prompt embeddings (which act as the key in the cross-attention
                layers). This is most likely equal to the max embedding range end, but we pass it explicitly to be sure.
        Returns:
            torch.Tensor: The cross-attention score mask.
                shape: (batch_size, query_seq_len, key_seq_len).
                dtype: float
        """
        batch_size = len(self._spatial_masks_by_seq_len)
        batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)]

        # Create an empty attention mask with the correct shape.
        attn_mask = torch.zeros((batch_size, query_seq_len, key_seq_len), dtype=self._dtype, device=self._device)

        for batch_idx in range(batch_size):
            batch_sample_spatial_masks = batch_spatial_masks[batch_idx]
            batch_sample_regions = self._regions[batch_idx]

            # Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1).
            _, num_prompts, _, _ = batch_sample_spatial_masks.shape
            batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1))

            for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
                batch_sample_query_scores = batch_sample_query_masks[0, prompt_idx, :, :].clone()
                batch_sample_query_mask = batch_sample_query_scores > 0.5
                batch_sample_query_scores[batch_sample_query_mask] = 0.0
                batch_sample_query_scores[~batch_sample_query_mask] = self._negative_cross_attn_mask_score
                attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_scores

        return attn_mask
@ -1,26 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
ConditioningData,
|
||||
ExtraConditioningInfo,
|
||||
SDXLConditioningInfo,
|
||||
)
|
||||
|
||||
from .cross_attention_control import (
|
||||
CrossAttentionType,
|
||||
CrossAttnControlContext,
|
||||
SwapCrossAttnContext,
|
||||
setup_cross_attention_control_attention_processors,
|
||||
IPAdapterData,
|
||||
Range,
|
||||
TextConditioningData,
|
||||
TextConditioningRegions,
|
||||
)
|
||||
from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData
|
||||
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
|
||||
|
||||
ModelForwardCallback: TypeAlias = Union[
|
||||
# x, t, conditioning, Optional[cross-attention kwargs]
|
||||
@ -58,31 +52,8 @@ class InvokeAIDiffuserComponent:
|
||||
self.conditioning = None
|
||||
self.model = model
|
||||
self.model_forward_callback = model_forward_callback
|
||||
self.cross_attention_control_context = None
|
||||
self.sequential_guidance = config.sequential_guidance
|
||||
|
||||
@contextmanager
|
||||
def custom_attention_context(
|
||||
self,
|
||||
unet: UNet2DConditionModel,
|
||||
extra_conditioning_info: Optional[ExtraConditioningInfo],
|
||||
):
|
||||
old_attn_processors = unet.attn_processors
|
||||
|
||||
try:
|
||||
self.cross_attention_control_context = CrossAttnControlContext(
|
||||
arguments=extra_conditioning_info.cross_attention_control_args,
|
||||
)
|
||||
setup_cross_attention_control_attention_processors(
|
||||
unet,
|
||||
self.cross_attention_control_context,
|
||||
)
|
||||
|
||||
yield None
|
||||
finally:
|
||||
self.cross_attention_control_context = None
|
||||
unet.set_attn_processor(old_attn_processors)
|
||||
|
||||
def do_controlnet_step(
|
||||
self,
|
||||
control_data,
|
||||
@ -90,7 +61,7 @@ class InvokeAIDiffuserComponent:
|
||||
timestep: torch.Tensor,
|
||||
step_index: int,
|
||||
total_step_count: int,
|
||||
conditioning_data,
|
||||
conditioning_data: TextConditioningData,
|
||||
):
|
||||
down_block_res_samples, mid_block_res_sample = None, None
|
||||
|
||||
@ -123,28 +94,28 @@ class InvokeAIDiffuserComponent:
|
||||
added_cond_kwargs = None
|
||||
|
||||
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
|
||||
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
|
||||
if conditioning_data.is_sdxl():
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
|
||||
"time_ids": conditioning_data.text_embeddings.add_time_ids,
|
||||
"text_embeds": conditioning_data.cond_text.pooled_embeds,
|
||||
"time_ids": conditioning_data.cond_text.add_time_ids,
|
||||
}
|
||||
encoder_hidden_states = conditioning_data.text_embeddings.embeds
|
||||
encoder_hidden_states = conditioning_data.cond_text.embeds
|
||||
encoder_attention_mask = None
|
||||
else:
|
||||
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
|
||||
if conditioning_data.is_sdxl():
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": torch.cat(
|
||||
[
|
||||
# TODO: how to pad? just by zeros? or even truncate?
|
||||
conditioning_data.unconditioned_embeddings.pooled_embeds,
|
||||
conditioning_data.text_embeddings.pooled_embeds,
|
||||
conditioning_data.uncond_text.pooled_embeds,
|
||||
conditioning_data.cond_text.pooled_embeds,
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
"time_ids": torch.cat(
|
||||
[
|
||||
conditioning_data.unconditioned_embeddings.add_time_ids,
|
||||
conditioning_data.text_embeddings.add_time_ids,
|
||||
conditioning_data.uncond_text.add_time_ids,
|
||||
conditioning_data.cond_text.add_time_ids,
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
@ -153,8 +124,8 @@ class InvokeAIDiffuserComponent:
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
) = self._concat_conditionings_for_batch(
|
||||
conditioning_data.unconditioned_embeddings.embeds,
|
||||
conditioning_data.text_embeddings.embeds,
|
||||
conditioning_data.uncond_text.embeds,
|
||||
conditioning_data.cond_text.embeds,
|
||||
)
|
||||
if isinstance(control_datum.weight, list):
|
||||
# if controlnet has multiple weights, use the weight for the current step
|
||||
@ -198,24 +169,15 @@ class InvokeAIDiffuserComponent:
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
conditioning_data: ConditioningData,
|
||||
conditioning_data: TextConditioningData,
|
||||
ip_adapter_data: Optional[list[IPAdapterData]],
|
||||
step_index: int,
|
||||
total_step_count: int,
|
||||
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
|
||||
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
|
||||
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
|
||||
):
|
||||
cross_attention_control_types_to_do = []
|
||||
if self.cross_attention_control_context is not None:
|
||||
percent_through = step_index / total_step_count
|
||||
cross_attention_control_types_to_do = (
|
||||
self.cross_attention_control_context.get_active_cross_attention_control_types_for_step(percent_through)
|
||||
)
|
||||
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
|
||||
|
||||
if wants_cross_attention_control or self.sequential_guidance:
|
||||
# If wants_cross_attention_control is True, we force the sequential mode to be used, because cross-attention
|
||||
# control is currently only supported in sequential mode.
|
||||
if self.sequential_guidance:
|
||||
(
|
||||
unconditioned_next_x,
|
||||
conditioned_next_x,
|
||||
@ -223,7 +185,9 @@ class InvokeAIDiffuserComponent:
|
||||
x=sample,
|
||||
sigma=timestep,
|
||||
conditioning_data=conditioning_data,
|
||||
cross_attention_control_types_to_do=cross_attention_control_types_to_do,
|
||||
ip_adapter_data=ip_adapter_data,
|
||||
step_index=step_index,
|
||||
total_step_count=total_step_count,
|
||||
down_block_additional_residuals=down_block_additional_residuals,
|
||||
mid_block_additional_residual=mid_block_additional_residual,
|
||||
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
|
||||
@ -236,6 +200,9 @@ class InvokeAIDiffuserComponent:
|
||||
x=sample,
|
||||
sigma=timestep,
|
||||
conditioning_data=conditioning_data,
|
||||
ip_adapter_data=ip_adapter_data,
|
||||
step_index=step_index,
|
||||
total_step_count=total_step_count,
|
||||
down_block_additional_residuals=down_block_additional_residuals,
|
||||
mid_block_additional_residual=mid_block_additional_residual,
|
||||
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
|
||||
@ -294,53 +261,84 @@ class InvokeAIDiffuserComponent:
|
||||
|
||||
def _apply_standard_conditioning(
|
||||
self,
|
||||
x,
|
||||
sigma,
|
||||
conditioning_data: ConditioningData,
|
||||
x: torch.Tensor,
|
||||
sigma: torch.Tensor,
|
||||
conditioning_data: TextConditioningData,
|
||||
ip_adapter_data: Optional[list[IPAdapterData]],
|
||||
step_index: int,
|
||||
total_step_count: int,
|
||||
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
|
||||
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
|
||||
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
|
||||
the cost of higher memory usage.
|
||||
"""
|
||||
x_twice = torch.cat([x] * 2)
|
||||
sigma_twice = torch.cat([sigma] * 2)
|
||||
|
||||
cross_attention_kwargs = None
|
||||
if conditioning_data.ip_adapter_conditioning is not None:
|
||||
cross_attention_kwargs = {}
|
||||
if ip_adapter_data is not None:
|
||||
ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]
|
||||
# Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
|
||||
cross_attention_kwargs = {
|
||||
"ip_adapter_image_prompt_embeds": [
|
||||
torch.stack(
|
||||
[ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds]
|
||||
)
|
||||
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
|
||||
]
|
||||
}
|
||||
image_prompt_embeds = [
|
||||
torch.stack([ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds])
|
||||
for ipa_conditioning in ip_adapter_conditioning
|
||||
]
|
||||
scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
|
||||
ip_masks = [ipa.mask for ipa in ip_adapter_data]
|
||||
regional_ip_data = RegionalIPData(
|
||||
image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device
|
||||
)
|
||||
cross_attention_kwargs["regional_ip_data"] = regional_ip_data
|
||||
|
||||
added_cond_kwargs = None
|
||||
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
|
||||
if conditioning_data.is_sdxl():
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": torch.cat(
|
||||
[
|
||||
# TODO: how to pad? just by zeros? or even truncate?
|
||||
conditioning_data.unconditioned_embeddings.pooled_embeds,
|
||||
conditioning_data.text_embeddings.pooled_embeds,
|
||||
conditioning_data.uncond_text.pooled_embeds,
|
||||
conditioning_data.cond_text.pooled_embeds,
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
"time_ids": torch.cat(
|
||||
[
|
||||
conditioning_data.unconditioned_embeddings.add_time_ids,
|
||||
conditioning_data.text_embeddings.add_time_ids,
|
||||
conditioning_data.uncond_text.add_time_ids,
|
||||
conditioning_data.cond_text.add_time_ids,
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
}
|
||||
|
||||
if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
|
||||
# TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
|
||||
# and masks are not changing from step-to-step, so this really only needs to be done once. While this seems
|
||||
# painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
|
||||
# the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
|
||||
# awkward to handle both standard conditioning and sequential conditioning further up the stack.
|
||||
regions = []
|
||||
for c, r in [
|
||||
(conditioning_data.uncond_text, conditioning_data.uncond_regions),
|
||||
(conditioning_data.cond_text, conditioning_data.cond_regions),
|
||||
]:
|
||||
if r is None:
|
||||
# Create a dummy mask and range for text conditioning that doesn't have region masks.
|
||||
_, _, h, w = x.shape
|
||||
r = TextConditioningRegions(
|
||||
masks=torch.ones((1, 1, h, w), dtype=x.dtype),
|
||||
ranges=[Range(start=0, end=c.embeds.shape[1])],
|
||||
)
|
||||
regions.append(r)
|
||||
|
||||
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
|
||||
regions=regions, device=x.device, dtype=x.dtype
|
||||
)
|
||||
cross_attention_kwargs["percent_through"] = step_index / total_step_count
|
||||
|
||||
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
|
||||
conditioning_data.unconditioned_embeddings.embeds, conditioning_data.text_embeddings.embeds
|
||||
conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds
|
||||
)
|
||||
both_results = self.model_forward_callback(
|
||||
x_twice,
@ -360,8 +358,10 @@ class InvokeAIDiffuserComponent:
self,
x: torch.Tensor,
sigma,
conditioning_data: ConditioningData,
cross_attention_control_types_to_do: list[CrossAttentionType],
conditioning_data: TextConditioningData,
ip_adapter_data: Optional[list[IPAdapterData]],
step_index: int,
total_step_count: int,
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
@ -391,53 +391,48 @@ class InvokeAIDiffuserComponent:
if mid_block_additional_residual is not None:
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)

# If cross-attention control is enabled, prepare the SwapCrossAttnContext.
cross_attn_processor_context = None
if self.cross_attention_control_context is not None:
# Note that the SwapCrossAttnContext is initialized with an empty list of cross_attention_types_to_do.
# This list is empty because cross-attention control is not applied in the unconditioned pass. This field
# will be populated before the conditioned pass.
cross_attn_processor_context = SwapCrossAttnContext(
modified_text_embeddings=self.cross_attention_control_context.arguments.edited_conditioning,
index_map=self.cross_attention_control_context.cross_attention_index_map,
mask=self.cross_attention_control_context.cross_attention_mask,
cross_attention_types_to_do=[],
)

#####################
# Unconditioned pass
#####################

cross_attention_kwargs = None
cross_attention_kwargs = {}

# Prepare IP-Adapter cross-attention kwargs for the unconditioned pass.
if conditioning_data.ip_adapter_conditioning is not None:
if ip_adapter_data is not None:
ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": [
torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
]
}
image_prompt_embeds = [
torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
for ipa_conditioning in ip_adapter_conditioning
]

# Prepare cross-attention control kwargs for the unconditioned pass.
if cross_attn_processor_context is not None:
cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
ip_masks = [ipa.mask for ipa in ip_adapter_data]
regional_ip_data = RegionalIPData(
image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device
)
cross_attention_kwargs["regional_ip_data"] = regional_ip_data

# Prepare SDXL conditioning kwargs for the unconditioned pass.
added_cond_kwargs = None
is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
if is_sdxl:
if conditioning_data.is_sdxl():
added_cond_kwargs = {
"text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
"time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
"text_embeds": conditioning_data.uncond_text.pooled_embeds,
"time_ids": conditioning_data.uncond_text.add_time_ids,
}

# Prepare prompt regions for the unconditioned pass.
if conditioning_data.uncond_regions is not None:
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
regions=[conditioning_data.uncond_regions], device=x.device, dtype=x.dtype
)
cross_attention_kwargs["percent_through"] = step_index / total_step_count

# Run unconditioned UNet denoising (i.e. negative prompt).
unconditioned_next_x = self.model_forward_callback(
x,
sigma,
conditioning_data.unconditioned_embeddings.embeds,
conditioning_data.uncond_text.embeds,
cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=uncond_down_block,
mid_block_additional_residual=uncond_mid_block,
@ -449,36 +444,43 @@ class InvokeAIDiffuserComponent:
# Conditioned pass
###################

cross_attention_kwargs = None
cross_attention_kwargs = {}

# Prepare IP-Adapter cross-attention kwargs for the conditioned pass.
if conditioning_data.ip_adapter_conditioning is not None:
if ip_adapter_data is not None:
ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": [
torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
]
}
image_prompt_embeds = [
torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
for ipa_conditioning in ip_adapter_conditioning
]

# Prepare cross-attention control kwargs for the conditioned pass.
if cross_attn_processor_context is not None:
cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
ip_masks = [ipa.mask for ipa in ip_adapter_data]
regional_ip_data = RegionalIPData(
image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device
)
cross_attention_kwargs["regional_ip_data"] = regional_ip_data

# Prepare SDXL conditioning kwargs for the conditioned pass.
added_cond_kwargs = None
if is_sdxl:
if conditioning_data.is_sdxl():
added_cond_kwargs = {
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
"time_ids": conditioning_data.text_embeddings.add_time_ids,
"text_embeds": conditioning_data.cond_text.pooled_embeds,
"time_ids": conditioning_data.cond_text.add_time_ids,
}

# Prepare prompt regions for the conditioned pass.
if conditioning_data.cond_regions is not None:
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
regions=[conditioning_data.cond_regions], device=x.device, dtype=x.dtype
)
cross_attention_kwargs["percent_through"] = step_index / total_step_count

# Run conditioned UNet denoising (i.e. positive prompt).
conditioned_next_x = self.model_forward_callback(
x,
sigma,
conditioning_data.text_embeddings.embeds,
conditioning_data.cond_text.embeds,
cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=cond_down_block,
mid_block_additional_residual=cond_mid_block,
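As a point of reference for how the unconditioned and conditioned predictions computed above are typically merged downstream, here is a hedged sketch of standard classifier-free guidance; the actual combination step lives elsewhere in InvokeAIDiffuserComponent, and the function name here is illustrative only.

import torch

def combine_with_cfg(unconditioned_next_x: torch.Tensor, conditioned_next_x: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # Classifier-free guidance: move the prediction away from the unconditioned
    # result and toward the conditioned one by `guidance_scale`.
    return unconditioned_next_x + guidance_scale * (conditioned_next_x - unconditioned_next_x)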

@ -0,0 +1,68 @@
from contextlib import contextmanager
from typing import List, Optional, TypedDict

from diffusers.models import UNet2DConditionModel

from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import (
CustomAttnProcessor2_0,
IPAdapterAttentionWeights,
)


class UNetIPAdapterData(TypedDict):
ip_adapter: IPAdapter
target_blocks: List[str]


class UNetAttentionPatcher:
"""A class for patching a UNet with CustomAttnProcessor2_0 attention layers."""

def __init__(self, ip_adapter_data: Optional[List[UNetIPAdapterData]]):
self._ip_adapters = ip_adapter_data

def _prepare_attention_processors(self, unet: UNet2DConditionModel):
"""Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
weights into them (if IP-Adapters are being applied).
Note that the `unet` param is only used to determine attention block dimensions and naming.
"""
# Construct a dict of attention processors based on the UNet's architecture.
attn_procs = {}
for idx, name in enumerate(unet.attn_processors.keys()):
if name.endswith("attn1.processor") or self._ip_adapters is None:
# "attn1" processors do not use IP-Adapters.
attn_procs[name] = CustomAttnProcessor2_0()
else:
# Collect the weights from each IP Adapter for the idx'th attention processor.
ip_adapter_attention_weights_collection: list[IPAdapterAttentionWeights] = []

for ip_adapter in self._ip_adapters:
ip_adapter_weights = ip_adapter["ip_adapter"].attn_weights.get_attention_processor_weights(idx)
skip = True
for block in ip_adapter["target_blocks"]:
if block in name:
skip = False
break
ip_adapter_attention_weights: IPAdapterAttentionWeights = IPAdapterAttentionWeights(
ip_adapter_weights=ip_adapter_weights, skip=skip
)
ip_adapter_attention_weights_collection.append(ip_adapter_attention_weights)

attn_procs[name] = CustomAttnProcessor2_0(ip_adapter_attention_weights_collection)

return attn_procs

@contextmanager
def apply_ip_adapter_attention(self, unet: UNet2DConditionModel):
"""A context manager that patches `unet` with CustomAttnProcessor2_0 attention layers."""
attn_procs = self._prepare_attention_processors(unet)
orig_attn_processors = unet.attn_processors

try:
# Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from
# the passed dict. So, if you wanted to keep the dict for future use, you'd have to make a
# moderately-shallow copy of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
unet.set_attn_processor(attn_procs)
yield None
finally:
unet.set_attn_processor(orig_attn_processors)
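A brief, hedged usage sketch for the new UNetAttentionPatcher above: assemble one UNetIPAdapterData entry per IP-Adapter, then patch the UNet's attention processors for the duration of denoising. The `unet` and `ip_adapter` objects, and the example target-block strings, are assumptions for illustration; any substring of an attention processor name works, since matching uses `block in name`.

# Illustrative only; `unet` (UNet2DConditionModel) and `ip_adapter` (IPAdapter)
# are assumed to have been loaded elsewhere, and the target blocks are example substrings.
ip_adapter_data: List[UNetIPAdapterData] = [
    {"ip_adapter": ip_adapter, "target_blocks": ["down_blocks.2", "up_blocks.1"]},
]
patcher = UNetAttentionPatcher(ip_adapter_data)
with patcher.apply_ip_adapter_attention(unet):
    ...  # run the denoising loop here; original attention processors are restored on exit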