Remove support for Prompt-to-Prompt cross-attention control (a.k.a. .swap()). The feature is not widely used, does not work with SDXL, and is incompatible with IP-Adapter and regional prompting. Its implementation is also deeply intertwined with both the text-embedding code and the UNet attention layers, which created a high maintenance burden. For all of these reasons, we have decided to drop support.
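For context, the removed feature was driven by Compel's prompt-editing syntax. A prompt along the following lines (an illustrative, hypothetical example, not taken from this diff) asked the sampler to keep the attention layout of the original prompt while substituting the edited token:

    a fluffy cat.swap(dog) sitting on a park bench

After this change, Compel still parses such prompts, but the cross-attention-control options it returns are discarded (note the `options` -> `_options` change below), so the swap syntax no longer has any effect on generation.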

Author: Ryan Dick (2024-03-11 18:22:49 -04:00), committed by Kent Keirsey
Parent: fe386252f3
Commit: 4a828818da
8 changed files with 15 additions and 386 deletions


@@ -22,7 +22,6 @@ from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
     ConditioningFieldData,
-    ExtraConditioningInfo,
     SDXLConditioningInfo,
 )
 from invokeai.backend.util.devices import torch_dtype

@@ -109,23 +108,11 @@ class CompelInvocation(BaseInvocation):
             if context.config.get().log_tokenization:
                 log_tokenization_for_conjunction(conjunction, tokenizer)

-            c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
-
-            ec = ExtraConditioningInfo(
-                tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
-                cross_attention_control_args=options.get("cross_attention_control", None),
-            )
+            c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)

             c = c.detach().to("cpu")

-        conditioning_data = ConditioningFieldData(
-            conditionings=[
-                BasicConditioningInfo(
-                    embeds=c,
-                    extra_conditioning=ec,
-                )
-            ]
-        )
+        conditioning_data = ConditioningFieldData(conditionings=[BasicConditioningInfo(embeds=c)])

         conditioning_name = context.conditioning.save(conditioning_data)

         return ConditioningOutput(

@@ -147,7 +134,7 @@ class SDXLPromptInvocationBase:
         get_pooled: bool,
         lora_prefix: str,
         zero_on_empty: bool,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         tokenizer_info = context.models.load(clip_field.tokenizer)
         tokenizer_model = tokenizer_info.model
         assert isinstance(tokenizer_model, CLIPTokenizer)

@@ -174,7 +161,7 @@ class SDXLPromptInvocationBase:
                 )
             else:
                 c_pooled = None
-            return c, c_pooled, None
+            return c, c_pooled

         def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in clip_field.loras:

@@ -219,17 +206,12 @@ class SDXLPromptInvocationBase:
                 log_tokenization_for_conjunction(conjunction, tokenizer)

             # TODO: ask for optimizations? to not run text_encoder twice
-            c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
+            c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)
             if get_pooled:
                 c_pooled = compel.conditioning_provider.get_pooled_embeddings([prompt])
             else:
                 c_pooled = None

-            ec = ExtraConditioningInfo(
-                tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
-                cross_attention_control_args=options.get("cross_attention_control", None),
-            )
-
         del tokenizer
         del text_encoder
         del tokenizer_info

@@ -239,7 +221,7 @@ class SDXLPromptInvocationBase:
         if c_pooled is not None:
             c_pooled = c_pooled.detach().to("cpu")

-        return c, c_pooled, ec
+        return c, c_pooled


 @invocation(

@@ -276,17 +258,13 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
-        c1, c1_pooled, ec1 = self.run_clip_compel(
-            context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True
-        )
+        c1, c1_pooled = self.run_clip_compel(context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True)
         if self.style.strip() == "":
-            c2, c2_pooled, ec2 = self.run_clip_compel(
+            c2, c2_pooled = self.run_clip_compel(
                 context, self.clip2, self.prompt, True, "lora_te2_", zero_on_empty=True
             )
         else:
-            c2, c2_pooled, ec2 = self.run_clip_compel(
-                context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True
-            )
+            c2, c2_pooled = self.run_clip_compel(context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True)

         original_size = (self.original_height, self.original_width)
         crop_coords = (self.crop_top, self.crop_left)

@@ -325,10 +303,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
         conditioning_data = ConditioningFieldData(
             conditionings=[
                 SDXLConditioningInfo(
-                    embeds=torch.cat([c1, c2], dim=-1),
-                    pooled_embeds=c2_pooled,
-                    add_time_ids=add_time_ids,
-                    extra_conditioning=ec1,
+                    embeds=torch.cat([c1, c2], dim=-1), pooled_embeds=c2_pooled, add_time_ids=add_time_ids
                 )
             ]
         )

@@ -368,7 +343,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
         # TODO: if there will appear lora for refiner - write proper prefix
-        c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>", zero_on_empty=False)
+        c2, c2_pooled = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>", zero_on_empty=False)

         original_size = (self.original_height, self.original_width)
         crop_coords = (self.crop_top, self.crop_left)

@@ -377,14 +352,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
         assert c2_pooled is not None
         conditioning_data = ConditioningFieldData(
-            conditionings=[
-                SDXLConditioningInfo(
-                    embeds=c2,
-                    pooled_embeds=c2_pooled,
-                    add_time_ids=add_time_ids,
-                    extra_conditioning=ec2,  # or None
-                )
-            ]
+            conditionings=[SDXLConditioningInfo(embeds=c2, pooled_embeds=c2_pooled, add_time_ids=add_time_ids)]
         )

         conditioning_name = context.conditioning.save(conditioning_data)


@@ -434,15 +434,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
         cur_text_embedding_len = 0
         processed_masks = []
         embedding_ranges = []
-        extra_conditioning = None

         for prompt_idx, text_embedding_info in enumerate(text_conditionings):
             mask = masks[prompt_idx]
-            if (
-                text_embedding_info.extra_conditioning is not None
-                and text_embedding_info.extra_conditioning.wants_cross_attention_control
-            ):
-                extra_conditioning = text_embedding_info.extra_conditioning

             if is_sdxl:
                 # We choose a random SDXLConditioningInfo's pooled_embeds and add_time_ids here, with a preference for

@@ -483,23 +477,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 ranges=embedding_ranges,
             )

-        if extra_conditioning is not None and len(text_conditionings) > 1:
-            raise ValueError(
-                "Prompt-to-prompt cross-attention control (a.k.a. `swap()`) is not supported when using multiple "
-                "prompts."
-            )
-
         if is_sdxl:
             return SDXLConditioningInfo(
-                embeds=text_embedding,
-                extra_conditioning=extra_conditioning,
-                pooled_embeds=pooled_embedding,
-                add_time_ids=add_time_ids,
+                embeds=text_embedding, pooled_embeds=pooled_embedding, add_time_ids=add_time_ids
             ), regions
-        return BasicConditioningInfo(
-            embeds=text_embedding,
-            extra_conditioning=extra_conditioning,
-        ), regions
+        return BasicConditioningInfo(embeds=text_embedding), regions

     def get_conditioning_data(
         self,


@@ -401,31 +401,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if timesteps.shape[0] == 0:
             return latents

-        extra_conditioning_info = conditioning_data.cond_text.extra_conditioning
-        use_cross_attention_control = (
-            extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
-        )
         use_ip_adapter = ip_adapter_data is not None
         use_regional_prompting = (
             conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
         )
-        if use_cross_attention_control and use_ip_adapter:
-            raise ValueError(
-                "Prompt-to-prompt cross-attention control (`.swap()`) and IP-Adapter cannot be used simultaneously."
-            )
-        if use_cross_attention_control and use_regional_prompting:
-            raise ValueError(
-                "Prompt-to-prompt cross-attention control (`.swap()`) and regional prompting cannot be used simultaneously."
-            )

         unet_attention_patcher = None
         self.use_ip_adapter = use_ip_adapter
         attn_ctx = nullcontext()
-        if use_cross_attention_control:
-            attn_ctx = self.invokeai_diffuser.custom_attention_context(
-                self.invokeai_diffuser.model,
-                extra_conditioning_info=extra_conditioning_info,
-            )
         if use_ip_adapter or use_regional_prompting:
             ip_adapters = [ipa.ip_adapter_model for ipa in ip_adapter_data] if use_ip_adapter else None
             unet_attention_patcher = UNetAttentionPatcher(ip_adapters)


@@ -3,29 +3,12 @@ from typing import List, Optional, Union

 import torch

-from .cross_attention_control import Arguments
-
-
-@dataclass
-class ExtraConditioningInfo:
-    """Extra conditioning information produced by Compel.
-
-    This is used for prompt-to-prompt cross-attention control (a.k.a. `.swap()` in Compel).
-    """
-
-    tokens_count_including_eos_bos: int
-    cross_attention_control_args: Optional[Arguments] = None
-
-    @property
-    def wants_cross_attention_control(self):
-        return self.cross_attention_control_args is not None
-

 @dataclass
 class BasicConditioningInfo:
     """SD 1/2 text conditioning information produced by Compel."""

     embeds: torch.Tensor
-    extra_conditioning: Optional[ExtraConditioningInfo]

     def to(self, device, dtype=None):
         self.embeds = self.embeds.to(device=device, dtype=dtype)


@@ -1,218 +0,0 @@
# adapted from bloc97's CrossAttentionControl colab
# https://github.com/bloc97/CrossAttentionControl

import enum
from dataclasses import dataclass, field
from typing import Optional

import torch
from compel.cross_attention_control import Arguments
from diffusers.models.attention_processor import Attention, SlicedAttnProcessor
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel

from invokeai.backend.util.devices import torch_dtype


class CrossAttentionType(enum.Enum):
    SELF = 1
    TOKENS = 2


class CrossAttnControlContext:
    def __init__(self, arguments: Arguments):
        """
        :param arguments: Arguments for the cross-attention control process
        """
        self.cross_attention_mask: Optional[torch.Tensor] = None
        self.cross_attention_index_map: Optional[torch.Tensor] = None
        self.arguments = arguments

    def get_active_cross_attention_control_types_for_step(
        self, percent_through: float = None
    ) -> list[CrossAttentionType]:
        """
        Should cross-attention control be applied on the given step?
        :param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
        :return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
        """
        if percent_through is None:
            return [CrossAttentionType.SELF, CrossAttentionType.TOKENS]

        opts = self.arguments.edit_options
        to_control = []
        if opts["s_start"] <= percent_through < opts["s_end"]:
            to_control.append(CrossAttentionType.SELF)
        if opts["t_start"] <= percent_through < opts["t_end"]:
            to_control.append(CrossAttentionType.TOKENS)
        return to_control


def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: CrossAttnControlContext):
    """
    Inject attention parameters and functions into the passed in model to enable cross attention editing.

    :param model: The unet model to inject into.
    :return: None
    """

    # adapted from init_attention_edit
    device = context.arguments.edited_conditioning.device

    # urgh. should this be hardcoded?
    max_length = 77
    # mask=1 means use base prompt attention, mask=0 means use edited prompt attention
    mask = torch.zeros(max_length, dtype=torch_dtype(device))
    indices_target = torch.arange(max_length, dtype=torch.long)
    indices = torch.arange(max_length, dtype=torch.long)
    for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
        if b0 < max_length:
            if name == "equal":  # or (name == "replace" and a1 - a0 == b1 - b0):
                # these tokens have not been edited
                indices[b0:b1] = indices_target[a0:a1]
                mask[b0:b1] = 1

    context.cross_attention_mask = mask.to(device)
    context.cross_attention_index_map = indices.to(device)
    old_attn_processors = unet.attn_processors
    if torch.backends.mps.is_available():
        # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
        unet.set_attn_processor(SwapCrossAttnProcessor())
    else:
        # try to re-use an existing slice size
        default_slice_size = 4
        slice_size = next(
            (p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size
        )
        unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))


@dataclass
class SwapCrossAttnContext:
    modified_text_embeddings: torch.Tensor
    index_map: torch.Tensor  # maps from original prompt token indices to the equivalent tokens in the modified prompt
    mask: torch.Tensor  # in the target space of the index_map
    cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=list)

    def wants_cross_attention_control(self, attn_type: CrossAttentionType) -> bool:
        return attn_type in self.cross_attention_types_to_do

    @classmethod
    def make_mask_and_index_map(
        cls, edit_opcodes: list[tuple[str, int, int, int, int]], max_length: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # mask=1 means use original prompt attention, mask=0 means use modified prompt attention
        mask = torch.zeros(max_length)
        indices_target = torch.arange(max_length, dtype=torch.long)
        indices = torch.arange(max_length, dtype=torch.long)
        for name, a0, a1, b0, b1 in edit_opcodes:
            if b0 < max_length:
                if name == "equal":
                    # these tokens remain the same as in the original prompt
                    indices[b0:b1] = indices_target[a0:a1]
                    mask[b0:b1] = 1

        return mask, indices


class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
    # TODO: dynamically pick slice size based on memory conditions

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        # kwargs
        swap_cross_attn_context: SwapCrossAttnContext = None,
        **kwargs,
    ):
        attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS

        # if cross-attention control is not in play, just call through to the base implementation.
        if (
            attention_type is CrossAttentionType.SELF
            or swap_cross_attn_context is None
            or not swap_cross_attn_context.wants_cross_attention_control(attention_type)
        ):
            # print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass")
            return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)
        # else:
        #     print(f"SwapCrossAttnContext for {attention_type} active")

        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(
            attention_mask=attention_mask,
            target_length=sequence_length,
            batch_size=batch_size,
        )

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

        original_text_embeddings = encoder_hidden_states
        modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings
        original_text_key = attn.to_k(original_text_embeddings)
        modified_text_key = attn.to_k(modified_text_embeddings)
        original_value = attn.to_v(original_text_embeddings)
        modified_value = attn.to_v(modified_text_embeddings)

        original_text_key = attn.head_to_batch_dim(original_text_key)
        modified_text_key = attn.head_to_batch_dim(modified_text_key)
        original_value = attn.head_to_batch_dim(original_value)
        modified_value = attn.head_to_batch_dim(modified_value)

        # compute slices and prepare output tensor
        batch_size_attention = query.shape[0]
        hidden_states = torch.zeros(
            (batch_size_attention, sequence_length, dim // attn.heads),
            device=query.device,
            dtype=query.dtype,
        )

        # do slices
        for i in range(max(1, hidden_states.shape[0] // self.slice_size)):
            start_idx = i * self.slice_size
            end_idx = (i + 1) * self.slice_size

            query_slice = query[start_idx:end_idx]
            original_key_slice = original_text_key[start_idx:end_idx]
            modified_key_slice = modified_text_key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            original_attn_slice = attn.get_attention_scores(query_slice, original_key_slice, attn_mask_slice)
            modified_attn_slice = attn.get_attention_scores(query_slice, modified_key_slice, attn_mask_slice)

            # because the prompt modifications may result in token sequences shifted forwards or backwards,
            # the original attention probabilities must be remapped to account for token index changes in the
            # modified prompt
            remapped_original_attn_slice = torch.index_select(
                original_attn_slice, -1, swap_cross_attn_context.index_map
            )

            # only some tokens taken from the original attention probabilities. this is controlled by the mask.
            mask = swap_cross_attn_context.mask
            inverse_mask = 1 - mask
            attn_slice = remapped_original_attn_slice * mask + modified_attn_slice * inverse_mask

            del remapped_original_attn_slice, modified_attn_slice

            attn_slice = torch.bmm(attn_slice, modified_value[start_idx:end_idx])
            hidden_states[start_idx:end_idx] = attn_slice

        # done
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


class SwapCrossAttnProcessor(SlicedSwapCrossAttnProcesser):
    def __init__(self):
        super(SwapCrossAttnProcessor, self).__init__(slice_size=int(1e9))  # massive slice size = don't slice


@@ -1,16 +1,13 @@
 from __future__ import annotations

 import math
-from contextlib import contextmanager
 from typing import Any, Callable, Optional, Union

 import torch
-from diffusers import UNet2DConditionModel
 from typing_extensions import TypeAlias

 from invokeai.app.services.config.config_default import get_config
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
-    ExtraConditioningInfo,
     IPAdapterConditioningInfo,
     Range,
     TextConditioningData,

@@ -18,13 +15,6 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
 )
 from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData

-from .cross_attention_control import (
-    CrossAttentionType,
-    CrossAttnControlContext,
-    SwapCrossAttnContext,
-    setup_cross_attention_control_attention_processors,
-)
-
 ModelForwardCallback: TypeAlias = Union[
     # x, t, conditioning, Optional[cross-attention kwargs]
     Callable[

@@ -61,31 +51,8 @@ class InvokeAIDiffuserComponent:
         self.conditioning = None
         self.model = model
         self.model_forward_callback = model_forward_callback
-        self.cross_attention_control_context = None
         self.sequential_guidance = config.sequential_guidance

-    @contextmanager
-    def custom_attention_context(
-        self,
-        unet: UNet2DConditionModel,
-        extra_conditioning_info: Optional[ExtraConditioningInfo],
-    ):
-        old_attn_processors = unet.attn_processors
-
-        try:
-            self.cross_attention_control_context = CrossAttnControlContext(
-                arguments=extra_conditioning_info.cross_attention_control_args,
-            )
-            setup_cross_attention_control_attention_processors(
-                unet,
-                self.cross_attention_control_context,
-            )
-
-            yield None
-        finally:
-            self.cross_attention_control_context = None
-            unet.set_attn_processor(old_attn_processors)
-
     def do_controlnet_step(
         self,
         control_data,

@@ -210,16 +177,8 @@ class InvokeAIDiffuserComponent:
         down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
     ):
         percent_through = step_index / total_step_count

-        cross_attention_control_types_to_do = []
-        if self.cross_attention_control_context is not None:
-            cross_attention_control_types_to_do = (
-                self.cross_attention_control_context.get_active_cross_attention_control_types_for_step(percent_through)
-            )
-        wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
-
-        if wants_cross_attention_control or self.sequential_guidance:
-            # If wants_cross_attention_control is True, we force the sequential mode to be used, because cross-attention
-            # control is currently only supported in sequential mode.
+        if self.sequential_guidance:
             (
                 unconditioned_next_x,
                 conditioned_next_x,

@@ -229,7 +188,6 @@ class InvokeAIDiffuserComponent:
                 conditioning_data=conditioning_data,
                 ip_adapter_conditioning=ip_adapter_conditioning,
                 percent_through=percent_through,
-                cross_attention_control_types_to_do=cross_attention_control_types_to_do,
                 down_block_additional_residuals=down_block_additional_residuals,
                 mid_block_additional_residual=mid_block_additional_residual,
                 down_intrablock_additional_residuals=down_intrablock_additional_residuals,

@@ -394,7 +352,6 @@ class InvokeAIDiffuserComponent:
         conditioning_data: TextConditioningData,
         ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
         percent_through: float,
-        cross_attention_control_types_to_do: list[CrossAttentionType],
         down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
         mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
         down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter

@@ -424,19 +381,6 @@ class InvokeAIDiffuserComponent:
         if mid_block_additional_residual is not None:
             uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)

-        # If cross-attention control is enabled, prepare the SwapCrossAttnContext.
-        cross_attn_processor_context = None
-        if self.cross_attention_control_context is not None:
-            # Note that the SwapCrossAttnContext is initialized with an empty list of cross_attention_types_to_do.
-            # This list is empty because cross-attention control is not applied in the unconditioned pass. This field
-            # will be populated before the conditioned pass.
-            cross_attn_processor_context = SwapCrossAttnContext(
-                modified_text_embeddings=self.cross_attention_control_context.arguments.edited_conditioning,
-                index_map=self.cross_attention_control_context.cross_attention_index_map,
-                mask=self.cross_attention_control_context.cross_attention_mask,
-                cross_attention_types_to_do=[],
-            )
-
         #####################
         # Unconditioned pass
         #####################

@@ -451,10 +395,6 @@ class InvokeAIDiffuserComponent:
                 for ipa_conditioning in ip_adapter_conditioning
             ]

-        # Prepare cross-attention control kwargs for the unconditioned pass.
-        if cross_attn_processor_context is not None:
-            cross_attention_kwargs["swap_cross_attn_context"] = cross_attn_processor_context
-
         # Prepare SDXL conditioning kwargs for the unconditioned pass.
         added_cond_kwargs = None
         if conditioning_data.is_sdxl():

@@ -496,11 +436,6 @@ class InvokeAIDiffuserComponent:
                 for ipa_conditioning in ip_adapter_conditioning
             ]

-        # Prepare cross-attention control kwargs for the conditioned pass.
-        if cross_attn_processor_context is not None:
-            cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
-            cross_attention_kwargs["swap_cross_attn_context"] = cross_attn_processor_context
-
         # Prepare SDXL conditioning kwargs for the conditioned pass.
         added_cond_kwargs = None
         if conditioning_data.is_sdxl():


@@ -76,7 +76,6 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineInterme
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
     ConditioningFieldData,
-    ExtraConditioningInfo,
     SDXLConditioningInfo,
 )
 from invokeai.backend.util.devices import CPU_DEVICE, CUDA_DEVICE, MPS_DEVICE, choose_precision, choose_torch_device

@@ -151,7 +150,6 @@ __all__ = [
     # invokeai.backend.stable_diffusion.diffusion.conditioning_data
     "BasicConditioningInfo",
     "ConditioningFieldData",
-    "ExtraConditioningInfo",
     "SDXLConditioningInfo",
     # invokeai.backend.stable_diffusion.diffusers_pipeline
     "PipelineIntermediateState",


@@ -256,7 +256,6 @@ module = [
     "invokeai.backend.model_management.seamless",
     "invokeai.backend.model_management.util",
     "invokeai.backend.stable_diffusion.diffusers_pipeline",
-    "invokeai.backend.stable_diffusion.diffusion.cross_attention_control",
     "invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion",
     "invokeai.backend.util.hotfixes",
     "invokeai.backend.util.mps_fixes",