Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Remove support for Prompt-to-Prompt cross-attention control (a.k.a. `.swap()`). This feature is not widely used. It does not work with SDXL and is incompatible with IP-Adapter and regional prompting. The implementation is also intertwined with both text embedding and the UNet attention layers, resulting in a high maintenance burden. For all of these reasons, we have decided to drop support.
Parent: fe386252f3
Commit: 4a828818da
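
For reference, the removed feature was driven by compel's cross-attention control prompt syntax. Below is a minimal sketch of how a `.swap()` prompt produced the now-ignored `cross_attention_control` options, assuming compel's documented API; the `tokenizer` and `text_encoder` are placeholder CLIP objects, not InvokeAI code:

    from compel import Compel

    compel = Compel(tokenizer=tokenizer, text_encoder=text_encoder)  # placeholder CLIP models

    # ".swap()" requested prompt-to-prompt editing: keep the attention maps of the
    # original token ("cat") while substituting the edited embeddings ("dog").
    conjunction = compel.parse_prompt_string('a ("cat").swap("dog") riding a bicycle')
    c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)

    # Before this commit, InvokeAI read the swap arguments from this key and wired
    # them into custom UNet attention processors; after it, the key is ignored.
    cross_attention_control_args = options.get("cross_attention_control", None)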
@@ -22,7 +22,6 @@ from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
     ConditioningFieldData,
-    ExtraConditioningInfo,
     SDXLConditioningInfo,
 )
 from invokeai.backend.util.devices import torch_dtype
@@ -109,23 +108,11 @@ class CompelInvocation(BaseInvocation):
         if context.config.get().log_tokenization:
             log_tokenization_for_conjunction(conjunction, tokenizer)
 
-        c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
-
-        ec = ExtraConditioningInfo(
-            tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
-            cross_attention_control_args=options.get("cross_attention_control", None),
-        )
+        c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)
 
         c = c.detach().to("cpu")
 
-        conditioning_data = ConditioningFieldData(
-            conditionings=[
-                BasicConditioningInfo(
-                    embeds=c,
-                    extra_conditioning=ec,
-                )
-            ]
-        )
+        conditioning_data = ConditioningFieldData(conditionings=[BasicConditioningInfo(embeds=c)])
 
         conditioning_name = context.conditioning.save(conditioning_data)
         return ConditioningOutput(
@@ -147,7 +134,7 @@ class SDXLPromptInvocationBase:
         get_pooled: bool,
         lora_prefix: str,
         zero_on_empty: bool,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         tokenizer_info = context.models.load(clip_field.tokenizer)
         tokenizer_model = tokenizer_info.model
         assert isinstance(tokenizer_model, CLIPTokenizer)
@@ -174,7 +161,7 @@ class SDXLPromptInvocationBase:
                 )
             else:
                 c_pooled = None
-            return c, c_pooled, None
+            return c, c_pooled
 
         def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in clip_field.loras:
@@ -219,17 +206,12 @@ class SDXLPromptInvocationBase:
                 log_tokenization_for_conjunction(conjunction, tokenizer)
 
            # TODO: ask for optimizations? to not run text_encoder twice
-            c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
+            c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)
            if get_pooled:
                c_pooled = compel.conditioning_provider.get_pooled_embeddings([prompt])
            else:
                c_pooled = None
 
-            ec = ExtraConditioningInfo(
-                tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
-                cross_attention_control_args=options.get("cross_attention_control", None),
-            )
-
            del tokenizer
            del text_encoder
            del tokenizer_info
@@ -239,7 +221,7 @@ class SDXLPromptInvocationBase:
         if c_pooled is not None:
             c_pooled = c_pooled.detach().to("cpu")
 
-        return c, c_pooled, ec
+        return c, c_pooled
 
 
 @invocation(
@@ -276,17 +258,13 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
-        c1, c1_pooled, ec1 = self.run_clip_compel(
-            context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True
-        )
+        c1, c1_pooled = self.run_clip_compel(context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True)
         if self.style.strip() == "":
-            c2, c2_pooled, ec2 = self.run_clip_compel(
+            c2, c2_pooled = self.run_clip_compel(
                 context, self.clip2, self.prompt, True, "lora_te2_", zero_on_empty=True
             )
         else:
-            c2, c2_pooled, ec2 = self.run_clip_compel(
-                context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True
-            )
+            c2, c2_pooled = self.run_clip_compel(context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True)
 
         original_size = (self.original_height, self.original_width)
         crop_coords = (self.crop_top, self.crop_left)
@@ -325,10 +303,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
         conditioning_data = ConditioningFieldData(
             conditionings=[
                 SDXLConditioningInfo(
-                    embeds=torch.cat([c1, c2], dim=-1),
-                    pooled_embeds=c2_pooled,
-                    add_time_ids=add_time_ids,
-                    extra_conditioning=ec1,
+                    embeds=torch.cat([c1, c2], dim=-1), pooled_embeds=c2_pooled, add_time_ids=add_time_ids
                 )
             ]
         )
@@ -368,7 +343,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
         # TODO: if there will appear lora for refiner - write proper prefix
-        c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>", zero_on_empty=False)
+        c2, c2_pooled = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>", zero_on_empty=False)
 
         original_size = (self.original_height, self.original_width)
         crop_coords = (self.crop_top, self.crop_left)
@@ -377,14 +352,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
 
         assert c2_pooled is not None
         conditioning_data = ConditioningFieldData(
-            conditionings=[
-                SDXLConditioningInfo(
-                    embeds=c2,
-                    pooled_embeds=c2_pooled,
-                    add_time_ids=add_time_ids,
-                    extra_conditioning=ec2,  # or None
-                )
-            ]
+            conditionings=[SDXLConditioningInfo(embeds=c2, pooled_embeds=c2_pooled, add_time_ids=add_time_ids)]
         )
 
         conditioning_name = context.conditioning.save(conditioning_data)
@@ -434,15 +434,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
         cur_text_embedding_len = 0
         processed_masks = []
         embedding_ranges = []
-        extra_conditioning = None
 
         for prompt_idx, text_embedding_info in enumerate(text_conditionings):
             mask = masks[prompt_idx]
-            if (
-                text_embedding_info.extra_conditioning is not None
-                and text_embedding_info.extra_conditioning.wants_cross_attention_control
-            ):
-                extra_conditioning = text_embedding_info.extra_conditioning
 
             if is_sdxl:
                 # We choose a random SDXLConditioningInfo's pooled_embeds and add_time_ids here, with a preference for
@@ -483,23 +477,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
             ranges=embedding_ranges,
         )
 
-        if extra_conditioning is not None and len(text_conditionings) > 1:
-            raise ValueError(
-                "Prompt-to-prompt cross-attention control (a.k.a. `swap()`) is not supported when using multiple "
-                "prompts."
-            )
-
         if is_sdxl:
             return SDXLConditioningInfo(
-                embeds=text_embedding,
-                extra_conditioning=extra_conditioning,
-                pooled_embeds=pooled_embedding,
-                add_time_ids=add_time_ids,
+                embeds=text_embedding, pooled_embeds=pooled_embedding, add_time_ids=add_time_ids
             ), regions
-        return BasicConditioningInfo(
-            embeds=text_embedding,
-            extra_conditioning=extra_conditioning,
-        ), regions
+        return BasicConditioningInfo(embeds=text_embedding), regions
 
     def get_conditioning_data(
         self,
@@ -401,31 +401,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if timesteps.shape[0] == 0:
             return latents
 
-        extra_conditioning_info = conditioning_data.cond_text.extra_conditioning
-        use_cross_attention_control = (
-            extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
-        )
         use_ip_adapter = ip_adapter_data is not None
         use_regional_prompting = (
             conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
         )
-        if use_cross_attention_control and use_ip_adapter:
-            raise ValueError(
-                "Prompt-to-prompt cross-attention control (`.swap()`) and IP-Adapter cannot be used simultaneously."
-            )
-        if use_cross_attention_control and use_regional_prompting:
-            raise ValueError(
-                "Prompt-to-prompt cross-attention control (`.swap()`) and regional prompting cannot be used simultaneously."
-            )
-
         unet_attention_patcher = None
         self.use_ip_adapter = use_ip_adapter
         attn_ctx = nullcontext()
-        if use_cross_attention_control:
-            attn_ctx = self.invokeai_diffuser.custom_attention_context(
-                self.invokeai_diffuser.model,
-                extra_conditioning_info=extra_conditioning_info,
-            )
         if use_ip_adapter or use_regional_prompting:
             ip_adapters = [ipa.ip_adapter_model for ipa in ip_adapter_data] if use_ip_adapter else None
             unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
@@ -3,29 +3,12 @@ from typing import List, Optional, Union
 
 import torch
 
-from .cross_attention_control import Arguments
-
-
-@dataclass
-class ExtraConditioningInfo:
-    """Extra conditioning information produced by Compel.
-
-    This is used for prompt-to-prompt cross-attention control (a.k.a. `.swap()` in Compel).
-    """
-
-    tokens_count_including_eos_bos: int
-    cross_attention_control_args: Optional[Arguments] = None
-
-    @property
-    def wants_cross_attention_control(self):
-        return self.cross_attention_control_args is not None
-
-
 @dataclass
 class BasicConditioningInfo:
     """SD 1/2 text conditioning information produced by Compel."""
 
     embeds: torch.Tensor
-    extra_conditioning: Optional[ExtraConditioningInfo]
 
     def to(self, device, dtype=None):
         self.embeds = self.embeds.to(device=device, dtype=dtype)
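
With `ExtraConditioningInfo` gone, callers build conditioning objects from embeddings alone. A minimal sketch of the post-change construction (toy tensor; this mirrors the one-liner in the CompelInvocation hunk above):

    import torch

    from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
        BasicConditioningInfo,
        ConditioningFieldData,
    )

    c = torch.zeros(1, 77, 768)  # toy SD 1/2 text embedding
    conditioning_data = ConditioningFieldData(conditionings=[BasicConditioningInfo(embeds=c)])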
@@ -1,218 +0,0 @@
-# adapted from bloc97's CrossAttentionControl colab
-# https://github.com/bloc97/CrossAttentionControl
-
-
-import enum
-from dataclasses import dataclass, field
-from typing import Optional
-
-import torch
-from compel.cross_attention_control import Arguments
-from diffusers.models.attention_processor import Attention, SlicedAttnProcessor
-from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
-
-from invokeai.backend.util.devices import torch_dtype
-
-
-class CrossAttentionType(enum.Enum):
-    SELF = 1
-    TOKENS = 2
-
-
-class CrossAttnControlContext:
-    def __init__(self, arguments: Arguments):
-        """
-        :param arguments: Arguments for the cross-attention control process
-        """
-        self.cross_attention_mask: Optional[torch.Tensor] = None
-        self.cross_attention_index_map: Optional[torch.Tensor] = None
-        self.arguments = arguments
-
-    def get_active_cross_attention_control_types_for_step(
-        self, percent_through: float = None
-    ) -> list[CrossAttentionType]:
-        """
-        Should cross-attention control be applied on the given step?
-        :param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
-        :return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
-        """
-        if percent_through is None:
-            return [CrossAttentionType.SELF, CrossAttentionType.TOKENS]
-
-        opts = self.arguments.edit_options
-        to_control = []
-        if opts["s_start"] <= percent_through < opts["s_end"]:
-            to_control.append(CrossAttentionType.SELF)
-        if opts["t_start"] <= percent_through < opts["t_end"]:
-            to_control.append(CrossAttentionType.TOKENS)
-        return to_control
-
-
-def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: CrossAttnControlContext):
-    """
-    Inject attention parameters and functions into the passed in model to enable cross attention editing.
-
-    :param model: The unet model to inject into.
-    :return: None
-    """
-
-    # adapted from init_attention_edit
-    device = context.arguments.edited_conditioning.device
-
-    # urgh. should this be hardcoded?
-    max_length = 77
-    # mask=1 means use base prompt attention, mask=0 means use edited prompt attention
-    mask = torch.zeros(max_length, dtype=torch_dtype(device))
-    indices_target = torch.arange(max_length, dtype=torch.long)
-    indices = torch.arange(max_length, dtype=torch.long)
-    for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
-        if b0 < max_length:
-            if name == "equal":  # or (name == "replace" and a1 - a0 == b1 - b0):
-                # these tokens have not been edited
-                indices[b0:b1] = indices_target[a0:a1]
-                mask[b0:b1] = 1
-
-    context.cross_attention_mask = mask.to(device)
-    context.cross_attention_index_map = indices.to(device)
-    old_attn_processors = unet.attn_processors
-    if torch.backends.mps.is_available():
-        # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
-        unet.set_attn_processor(SwapCrossAttnProcessor())
-    else:
-        # try to re-use an existing slice size
-        default_slice_size = 4
-        slice_size = next(
-            (p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size
-        )
-        unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
-
-
-@dataclass
-class SwapCrossAttnContext:
-    modified_text_embeddings: torch.Tensor
-    index_map: torch.Tensor  # maps from original prompt token indices to the equivalent tokens in the modified prompt
-    mask: torch.Tensor  # in the target space of the index_map
-    cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=list)
-
-    def wants_cross_attention_control(self, attn_type: CrossAttentionType) -> bool:
-        return attn_type in self.cross_attention_types_to_do
-
-    @classmethod
-    def make_mask_and_index_map(
-        cls, edit_opcodes: list[tuple[str, int, int, int, int]], max_length: int
-    ) -> tuple[torch.Tensor, torch.Tensor]:
-        # mask=1 means use original prompt attention, mask=0 means use modified prompt attention
-        mask = torch.zeros(max_length)
-        indices_target = torch.arange(max_length, dtype=torch.long)
-        indices = torch.arange(max_length, dtype=torch.long)
-        for name, a0, a1, b0, b1 in edit_opcodes:
-            if b0 < max_length:
-                if name == "equal":
-                    # these tokens remain the same as in the original prompt
-                    indices[b0:b1] = indices_target[a0:a1]
-                    mask[b0:b1] = 1
-
-        return mask, indices
-
-
-class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
-    # TODO: dynamically pick slice size based on memory conditions
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        # kwargs
-        swap_cross_attn_context: SwapCrossAttnContext = None,
-        **kwargs,
-    ):
-        attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS
-
-        # if cross-attention control is not in play, just call through to the base implementation.
-        if (
-            attention_type is CrossAttentionType.SELF
-            or swap_cross_attn_context is None
-            or not swap_cross_attn_context.wants_cross_attention_control(attention_type)
-        ):
-            # print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass")
-            return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)
-        # else:
-        #     print(f"SwapCrossAttnContext for {attention_type} active")
-
-        batch_size, sequence_length, _ = hidden_states.shape
-        attention_mask = attn.prepare_attention_mask(
-            attention_mask=attention_mask,
-            target_length=sequence_length,
-            batch_size=batch_size,
-        )
-
-        query = attn.to_q(hidden_states)
-        dim = query.shape[-1]
-        query = attn.head_to_batch_dim(query)
-
-        original_text_embeddings = encoder_hidden_states
-        modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings
-        original_text_key = attn.to_k(original_text_embeddings)
-        modified_text_key = attn.to_k(modified_text_embeddings)
-        original_value = attn.to_v(original_text_embeddings)
-        modified_value = attn.to_v(modified_text_embeddings)
-
-        original_text_key = attn.head_to_batch_dim(original_text_key)
-        modified_text_key = attn.head_to_batch_dim(modified_text_key)
-        original_value = attn.head_to_batch_dim(original_value)
-        modified_value = attn.head_to_batch_dim(modified_value)
-
-        # compute slices and prepare output tensor
-        batch_size_attention = query.shape[0]
-        hidden_states = torch.zeros(
-            (batch_size_attention, sequence_length, dim // attn.heads),
-            device=query.device,
-            dtype=query.dtype,
-        )
-
-        # do slices
-        for i in range(max(1, hidden_states.shape[0] // self.slice_size)):
-            start_idx = i * self.slice_size
-            end_idx = (i + 1) * self.slice_size
-
-            query_slice = query[start_idx:end_idx]
-            original_key_slice = original_text_key[start_idx:end_idx]
-            modified_key_slice = modified_text_key[start_idx:end_idx]
-            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
-
-            original_attn_slice = attn.get_attention_scores(query_slice, original_key_slice, attn_mask_slice)
-            modified_attn_slice = attn.get_attention_scores(query_slice, modified_key_slice, attn_mask_slice)
-
-            # because the prompt modifications may result in token sequences shifted forwards or backwards,
-            # the original attention probabilities must be remapped to account for token index changes in the
-            # modified prompt
-            remapped_original_attn_slice = torch.index_select(
-                original_attn_slice, -1, swap_cross_attn_context.index_map
-            )
-
-            # only some tokens taken from the original attention probabilities. this is controlled by the mask.
-            mask = swap_cross_attn_context.mask
-            inverse_mask = 1 - mask
-            attn_slice = remapped_original_attn_slice * mask + modified_attn_slice * inverse_mask
-
-            del remapped_original_attn_slice, modified_attn_slice
-
-            attn_slice = torch.bmm(attn_slice, modified_value[start_idx:end_idx])
-            hidden_states[start_idx:end_idx] = attn_slice
-
-        # done
-        hidden_states = attn.batch_to_head_dim(hidden_states)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        return hidden_states
-
-
-class SwapCrossAttnProcessor(SlicedSwapCrossAttnProcesser):
-    def __init__(self):
-        super(SwapCrossAttnProcessor, self).__init__(slice_size=int(1e9))  # massive slice size = don't slice
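
The remap-and-blend at the heart of the deleted `SlicedSwapCrossAttnProcesser` can be illustrated with toy tensors (a sketch only, not InvokeAI code): where the opcode mask is 1, the original prompt's attention probabilities are kept (after `index_map` realigns shifted tokens); where it is 0, the edited prompt's attention takes over:

    import torch

    max_length = 6
    # Suppose tokens 0-2 matched an "equal" opcode and tokens 3-5 were swapped.
    mask = torch.tensor([1.0, 1.0, 1.0, 0.0, 0.0, 0.0])
    index_map = torch.arange(max_length)  # identity here; differs when tokens shift position

    original_attn = torch.rand(2, 4, max_length)  # (batch*heads, queries, tokens)
    modified_attn = torch.rand(2, 4, max_length)

    # Remap original probabilities to the modified prompt's token positions,
    # then blend: original attention where mask=1, edited attention elsewhere.
    remapped = torch.index_select(original_attn, -1, index_map)
    blended = remapped * mask + modified_attn * (1 - mask)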
@@ -1,16 +1,13 @@
 from __future__ import annotations
 
 import math
-from contextlib import contextmanager
 from typing import Any, Callable, Optional, Union
 
 import torch
-from diffusers import UNet2DConditionModel
 from typing_extensions import TypeAlias
 
 from invokeai.app.services.config.config_default import get_config
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
-    ExtraConditioningInfo,
     IPAdapterConditioningInfo,
     Range,
     TextConditioningData,
@@ -18,13 +15,6 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
 )
 from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
 
-from .cross_attention_control import (
-    CrossAttentionType,
-    CrossAttnControlContext,
-    SwapCrossAttnContext,
-    setup_cross_attention_control_attention_processors,
-)
-
 ModelForwardCallback: TypeAlias = Union[
     # x, t, conditioning, Optional[cross-attention kwargs]
     Callable[
@@ -61,31 +51,8 @@ class InvokeAIDiffuserComponent:
         self.conditioning = None
         self.model = model
         self.model_forward_callback = model_forward_callback
-        self.cross_attention_control_context = None
         self.sequential_guidance = config.sequential_guidance
 
-    @contextmanager
-    def custom_attention_context(
-        self,
-        unet: UNet2DConditionModel,
-        extra_conditioning_info: Optional[ExtraConditioningInfo],
-    ):
-        old_attn_processors = unet.attn_processors
-
-        try:
-            self.cross_attention_control_context = CrossAttnControlContext(
-                arguments=extra_conditioning_info.cross_attention_control_args,
-            )
-            setup_cross_attention_control_attention_processors(
-                unet,
-                self.cross_attention_control_context,
-            )
-
-            yield None
-        finally:
-            self.cross_attention_control_context = None
-            unet.set_attn_processor(old_attn_processors)
-
     def do_controlnet_step(
         self,
         control_data,
@@ -210,16 +177,8 @@ class InvokeAIDiffuserComponent:
         down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
     ):
         percent_through = step_index / total_step_count
-        cross_attention_control_types_to_do = []
-        if self.cross_attention_control_context is not None:
-            cross_attention_control_types_to_do = (
-                self.cross_attention_control_context.get_active_cross_attention_control_types_for_step(percent_through)
-            )
-        wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
 
-        if wants_cross_attention_control or self.sequential_guidance:
-            # If wants_cross_attention_control is True, we force the sequential mode to be used, because cross-attention
-            # control is currently only supported in sequential mode.
+        if self.sequential_guidance:
             (
                 unconditioned_next_x,
                 conditioned_next_x,
@@ -229,7 +188,6 @@ class InvokeAIDiffuserComponent:
                 conditioning_data=conditioning_data,
                 ip_adapter_conditioning=ip_adapter_conditioning,
                 percent_through=percent_through,
-                cross_attention_control_types_to_do=cross_attention_control_types_to_do,
                 down_block_additional_residuals=down_block_additional_residuals,
                 mid_block_additional_residual=mid_block_additional_residual,
                 down_intrablock_additional_residuals=down_intrablock_additional_residuals,
@@ -394,7 +352,6 @@ class InvokeAIDiffuserComponent:
         conditioning_data: TextConditioningData,
         ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
         percent_through: float,
-        cross_attention_control_types_to_do: list[CrossAttentionType],
         down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
         mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
         down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
@@ -424,19 +381,6 @@ class InvokeAIDiffuserComponent:
         if mid_block_additional_residual is not None:
             uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
 
-        # If cross-attention control is enabled, prepare the SwapCrossAttnContext.
-        cross_attn_processor_context = None
-        if self.cross_attention_control_context is not None:
-            # Note that the SwapCrossAttnContext is initialized with an empty list of cross_attention_types_to_do.
-            # This list is empty because cross-attention control is not applied in the unconditioned pass. This field
-            # will be populated before the conditioned pass.
-            cross_attn_processor_context = SwapCrossAttnContext(
-                modified_text_embeddings=self.cross_attention_control_context.arguments.edited_conditioning,
-                index_map=self.cross_attention_control_context.cross_attention_index_map,
-                mask=self.cross_attention_control_context.cross_attention_mask,
-                cross_attention_types_to_do=[],
-            )
-
         #####################
         # Unconditioned pass
         #####################
@@ -451,10 +395,6 @@ class InvokeAIDiffuserComponent:
                 for ipa_conditioning in ip_adapter_conditioning
             ]
 
-            # Prepare cross-attention control kwargs for the unconditioned pass.
-            if cross_attn_processor_context is not None:
-                cross_attention_kwargs["swap_cross_attn_context"] = cross_attn_processor_context
-
             # Prepare SDXL conditioning kwargs for the unconditioned pass.
             added_cond_kwargs = None
             if conditioning_data.is_sdxl():
@@ -496,11 +436,6 @@ class InvokeAIDiffuserComponent:
                 for ipa_conditioning in ip_adapter_conditioning
             ]
 
-            # Prepare cross-attention control kwargs for the conditioned pass.
-            if cross_attn_processor_context is not None:
-                cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
-                cross_attention_kwargs["swap_cross_attn_context"] = cross_attn_processor_context
-
             # Prepare SDXL conditioning kwargs for the conditioned pass.
             added_cond_kwargs = None
             if conditioning_data.is_sdxl():
@@ -76,7 +76,6 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
     ConditioningFieldData,
-    ExtraConditioningInfo,
     SDXLConditioningInfo,
 )
 from invokeai.backend.util.devices import CPU_DEVICE, CUDA_DEVICE, MPS_DEVICE, choose_precision, choose_torch_device
@@ -151,7 +150,6 @@ __all__ = [
     # invokeai.backend.stable_diffusion.diffusion.conditioning_data
     "BasicConditioningInfo",
    "ConditioningFieldData",
-    "ExtraConditioningInfo",
    "SDXLConditioningInfo",
    # invokeai.backend.stable_diffusion.diffusers_pipeline
    "PipelineIntermediateState",
@@ -256,7 +256,6 @@ module = [
     "invokeai.backend.model_management.seamless",
     "invokeai.backend.model_management.util",
     "invokeai.backend.stable_diffusion.diffusers_pipeline",
-    "invokeai.backend.stable_diffusion.diffusion.cross_attention_control",
     "invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion",
     "invokeai.backend.util.hotfixes",
     "invokeai.backend.util.mps_fixes",