From 4a828818dab0794bce761fc509cbb7114fef1aaf Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 11 Mar 2024 18:22:49 -0400 Subject: [PATCH] Remove support for Prompt-to-Prompt cross-attention control (aka .swap()). This feature is not widely used. It does not work with SDXL and is incompatible with IP-Adapter and regional prompting. The implementation is also intertwined with both text embedding and the UNet attention layers, resulting in a high maintenance burden. For all of these reasons, we have decided to drop support. --- invokeai/app/invocations/compel.py | 56 +---- invokeai/app/invocations/latent.py | 22 +- .../stable_diffusion/diffusers_pipeline.py | 18 -- .../diffusion/conditioning_data.py | 17 -- .../diffusion/cross_attention_control.py | 218 ------------------ .../diffusion/shared_invokeai_diffusion.py | 67 +----- invokeai/invocation_api/__init__.py | 2 - pyproject.toml | 1 - 8 files changed, 15 insertions(+), 386 deletions(-) delete mode 100644 invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 92012691ea..481a2d2e4b 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -22,7 +22,6 @@ from invokeai.backend.model_patcher import ModelPatcher from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( BasicConditioningInfo, ConditioningFieldData, - ExtraConditioningInfo, SDXLConditioningInfo, ) from invokeai.backend.util.devices import torch_dtype @@ -109,23 +108,11 @@ class CompelInvocation(BaseInvocation): if context.config.get().log_tokenization: log_tokenization_for_conjunction(conjunction, tokenizer) - c, options = compel.build_conditioning_tensor_for_conjunction(conjunction) - - ec = ExtraConditioningInfo( - tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction), - cross_attention_control_args=options.get("cross_attention_control", None), - ) + c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction) c = c.detach().to("cpu") - conditioning_data = ConditioningFieldData( - conditionings=[ - BasicConditioningInfo( - embeds=c, - extra_conditioning=ec, - ) - ] - ) + conditioning_data = ConditioningFieldData(conditionings=[BasicConditioningInfo(embeds=c)]) conditioning_name = context.conditioning.save(conditioning_data) return ConditioningOutput( @@ -147,7 +134,7 @@ class SDXLPromptInvocationBase: get_pooled: bool, lora_prefix: str, zero_on_empty: bool, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: tokenizer_info = context.models.load(clip_field.tokenizer) tokenizer_model = tokenizer_info.model assert isinstance(tokenizer_model, CLIPTokenizer) @@ -174,7 +161,7 @@ class SDXLPromptInvocationBase: ) else: c_pooled = None - return c, c_pooled, None + return c, c_pooled def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for lora in clip_field.loras: @@ -219,17 +206,12 @@ class SDXLPromptInvocationBase: log_tokenization_for_conjunction(conjunction, tokenizer) # TODO: ask for optimizations? to not run text_encoder twice - c, options = compel.build_conditioning_tensor_for_conjunction(conjunction) + c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction) if get_pooled: c_pooled = compel.conditioning_provider.get_pooled_embeddings([prompt]) else: c_pooled = None - ec = ExtraConditioningInfo( - tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction), - cross_attention_control_args=options.get("cross_attention_control", None), - ) - del tokenizer del text_encoder del tokenizer_info @@ -239,7 +221,7 @@ class SDXLPromptInvocationBase: if c_pooled is not None: c_pooled = c_pooled.detach().to("cpu") - return c, c_pooled, ec + return c, c_pooled @invocation( @@ -276,17 +258,13 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): @torch.no_grad() def invoke(self, context: InvocationContext) -> ConditioningOutput: - c1, c1_pooled, ec1 = self.run_clip_compel( - context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True - ) + c1, c1_pooled = self.run_clip_compel(context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True) if self.style.strip() == "": - c2, c2_pooled, ec2 = self.run_clip_compel( + c2, c2_pooled = self.run_clip_compel( context, self.clip2, self.prompt, True, "lora_te2_", zero_on_empty=True ) else: - c2, c2_pooled, ec2 = self.run_clip_compel( - context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True - ) + c2, c2_pooled = self.run_clip_compel(context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True) original_size = (self.original_height, self.original_width) crop_coords = (self.crop_top, self.crop_left) @@ -325,10 +303,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): conditioning_data = ConditioningFieldData( conditionings=[ SDXLConditioningInfo( - embeds=torch.cat([c1, c2], dim=-1), - pooled_embeds=c2_pooled, - add_time_ids=add_time_ids, - extra_conditioning=ec1, + embeds=torch.cat([c1, c2], dim=-1), pooled_embeds=c2_pooled, add_time_ids=add_time_ids ) ] ) @@ -368,7 +343,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase @torch.no_grad() def invoke(self, context: InvocationContext) -> ConditioningOutput: # TODO: if there will appear lora for refiner - write proper prefix - c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "", zero_on_empty=False) + c2, c2_pooled = self.run_clip_compel(context, self.clip2, self.style, True, "", zero_on_empty=False) original_size = (self.original_height, self.original_width) crop_coords = (self.crop_top, self.crop_left) @@ -377,14 +352,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase assert c2_pooled is not None conditioning_data = ConditioningFieldData( - conditionings=[ - SDXLConditioningInfo( - embeds=c2, - pooled_embeds=c2_pooled, - add_time_ids=add_time_ids, - extra_conditioning=ec2, # or None - ) - ] + conditionings=[SDXLConditioningInfo(embeds=c2, pooled_embeds=c2_pooled, add_time_ids=add_time_ids)] ) conditioning_name = context.conditioning.save(conditioning_data) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index d5babe42cc..345621b38f 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -434,15 +434,9 @@ class DenoiseLatentsInvocation(BaseInvocation): cur_text_embedding_len = 0 processed_masks = [] embedding_ranges = [] - extra_conditioning = None for prompt_idx, text_embedding_info in enumerate(text_conditionings): mask = masks[prompt_idx] - if ( - text_embedding_info.extra_conditioning is not None - and text_embedding_info.extra_conditioning.wants_cross_attention_control - ): - extra_conditioning = text_embedding_info.extra_conditioning if is_sdxl: # We choose a random SDXLConditioningInfo's pooled_embeds and add_time_ids here, with a preference for @@ -483,23 +477,11 @@ class DenoiseLatentsInvocation(BaseInvocation): ranges=embedding_ranges, ) - if extra_conditioning is not None and len(text_conditionings) > 1: - raise ValueError( - "Prompt-to-prompt cross-attention control (a.k.a. `swap()`) is not supported when using multiple " - "prompts." - ) - if is_sdxl: return SDXLConditioningInfo( - embeds=text_embedding, - extra_conditioning=extra_conditioning, - pooled_embeds=pooled_embedding, - add_time_ids=add_time_ids, + embeds=text_embedding, pooled_embeds=pooled_embedding, add_time_ids=add_time_ids ), regions - return BasicConditioningInfo( - embeds=text_embedding, - extra_conditioning=extra_conditioning, - ), regions + return BasicConditioningInfo(embeds=text_embedding), regions def get_conditioning_data( self, diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 278a53eb0f..25acd32fea 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -401,31 +401,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): if timesteps.shape[0] == 0: return latents - extra_conditioning_info = conditioning_data.cond_text.extra_conditioning - use_cross_attention_control = ( - extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control - ) use_ip_adapter = ip_adapter_data is not None use_regional_prompting = ( conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None ) - if use_cross_attention_control and use_ip_adapter: - raise ValueError( - "Prompt-to-prompt cross-attention control (`.swap()`) and IP-Adapter cannot be used simultaneously." - ) - if use_cross_attention_control and use_regional_prompting: - raise ValueError( - "Prompt-to-prompt cross-attention control (`.swap()`) and regional prompting cannot be used simultaneously." - ) - unet_attention_patcher = None self.use_ip_adapter = use_ip_adapter attn_ctx = nullcontext() - if use_cross_attention_control: - attn_ctx = self.invokeai_diffuser.custom_attention_context( - self.invokeai_diffuser.model, - extra_conditioning_info=extra_conditioning_info, - ) if use_ip_adapter or use_regional_prompting: ip_adapters = [ipa.ip_adapter_model for ipa in ip_adapter_data] if use_ip_adapter else None unet_attention_patcher = UNetAttentionPatcher(ip_adapters) diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index 6ef6d68fca..102de19428 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -3,29 +3,12 @@ from typing import List, Optional, Union import torch -from .cross_attention_control import Arguments - - -@dataclass -class ExtraConditioningInfo: - """Extra conditioning information produced by Compel. - This is used for prompt-to-prompt cross-attention control (a.k.a. `.swap()` in Compel). - """ - - tokens_count_including_eos_bos: int - cross_attention_control_args: Optional[Arguments] = None - - @property - def wants_cross_attention_control(self): - return self.cross_attention_control_args is not None - @dataclass class BasicConditioningInfo: """SD 1/2 text conditioning information produced by Compel.""" embeds: torch.Tensor - extra_conditioning: Optional[ExtraConditioningInfo] def to(self, device, dtype=None): self.embeds = self.embeds.to(device=device, dtype=dtype) diff --git a/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py b/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py deleted file mode 100644 index 2a0fcccd89..0000000000 --- a/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py +++ /dev/null @@ -1,218 +0,0 @@ -# adapted from bloc97's CrossAttentionControl colab -# https://github.com/bloc97/CrossAttentionControl - - -import enum -from dataclasses import dataclass, field -from typing import Optional - -import torch -from compel.cross_attention_control import Arguments -from diffusers.models.attention_processor import Attention, SlicedAttnProcessor -from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel - -from invokeai.backend.util.devices import torch_dtype - - -class CrossAttentionType(enum.Enum): - SELF = 1 - TOKENS = 2 - - -class CrossAttnControlContext: - def __init__(self, arguments: Arguments): - """ - :param arguments: Arguments for the cross-attention control process - """ - self.cross_attention_mask: Optional[torch.Tensor] = None - self.cross_attention_index_map: Optional[torch.Tensor] = None - self.arguments = arguments - - def get_active_cross_attention_control_types_for_step( - self, percent_through: float = None - ) -> list[CrossAttentionType]: - """ - Should cross-attention control be applied on the given step? - :param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0. - :return: A list of attention types that cross-attention control should be performed for on the given step. May be []. - """ - if percent_through is None: - return [CrossAttentionType.SELF, CrossAttentionType.TOKENS] - - opts = self.arguments.edit_options - to_control = [] - if opts["s_start"] <= percent_through < opts["s_end"]: - to_control.append(CrossAttentionType.SELF) - if opts["t_start"] <= percent_through < opts["t_end"]: - to_control.append(CrossAttentionType.TOKENS) - return to_control - - -def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: CrossAttnControlContext): - """ - Inject attention parameters and functions into the passed in model to enable cross attention editing. - - :param model: The unet model to inject into. - :return: None - """ - - # adapted from init_attention_edit - device = context.arguments.edited_conditioning.device - - # urgh. should this be hardcoded? - max_length = 77 - # mask=1 means use base prompt attention, mask=0 means use edited prompt attention - mask = torch.zeros(max_length, dtype=torch_dtype(device)) - indices_target = torch.arange(max_length, dtype=torch.long) - indices = torch.arange(max_length, dtype=torch.long) - for name, a0, a1, b0, b1 in context.arguments.edit_opcodes: - if b0 < max_length: - if name == "equal": # or (name == "replace" and a1 - a0 == b1 - b0): - # these tokens have not been edited - indices[b0:b1] = indices_target[a0:a1] - mask[b0:b1] = 1 - - context.cross_attention_mask = mask.to(device) - context.cross_attention_index_map = indices.to(device) - old_attn_processors = unet.attn_processors - if torch.backends.mps.is_available(): - # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS - unet.set_attn_processor(SwapCrossAttnProcessor()) - else: - # try to re-use an existing slice size - default_slice_size = 4 - slice_size = next( - (p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size - ) - unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size)) - - -@dataclass -class SwapCrossAttnContext: - modified_text_embeddings: torch.Tensor - index_map: torch.Tensor # maps from original prompt token indices to the equivalent tokens in the modified prompt - mask: torch.Tensor # in the target space of the index_map - cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=list) - - def wants_cross_attention_control(self, attn_type: CrossAttentionType) -> bool: - return attn_type in self.cross_attention_types_to_do - - @classmethod - def make_mask_and_index_map( - cls, edit_opcodes: list[tuple[str, int, int, int, int]], max_length: int - ) -> tuple[torch.Tensor, torch.Tensor]: - # mask=1 means use original prompt attention, mask=0 means use modified prompt attention - mask = torch.zeros(max_length) - indices_target = torch.arange(max_length, dtype=torch.long) - indices = torch.arange(max_length, dtype=torch.long) - for name, a0, a1, b0, b1 in edit_opcodes: - if b0 < max_length: - if name == "equal": - # these tokens remain the same as in the original prompt - indices[b0:b1] = indices_target[a0:a1] - mask[b0:b1] = 1 - - return mask, indices - - -class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor): - # TODO: dynamically pick slice size based on memory conditions - - def __call__( - self, - attn: Attention, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - # kwargs - swap_cross_attn_context: SwapCrossAttnContext = None, - **kwargs, - ): - attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS - - # if cross-attention control is not in play, just call through to the base implementation. - if ( - attention_type is CrossAttentionType.SELF - or swap_cross_attn_context is None - or not swap_cross_attn_context.wants_cross_attention_control(attention_type) - ): - # print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass") - return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask) - # else: - # print(f"SwapCrossAttnContext for {attention_type} active") - - batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask( - attention_mask=attention_mask, - target_length=sequence_length, - batch_size=batch_size, - ) - - query = attn.to_q(hidden_states) - dim = query.shape[-1] - query = attn.head_to_batch_dim(query) - - original_text_embeddings = encoder_hidden_states - modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings - original_text_key = attn.to_k(original_text_embeddings) - modified_text_key = attn.to_k(modified_text_embeddings) - original_value = attn.to_v(original_text_embeddings) - modified_value = attn.to_v(modified_text_embeddings) - - original_text_key = attn.head_to_batch_dim(original_text_key) - modified_text_key = attn.head_to_batch_dim(modified_text_key) - original_value = attn.head_to_batch_dim(original_value) - modified_value = attn.head_to_batch_dim(modified_value) - - # compute slices and prepare output tensor - batch_size_attention = query.shape[0] - hidden_states = torch.zeros( - (batch_size_attention, sequence_length, dim // attn.heads), - device=query.device, - dtype=query.dtype, - ) - - # do slices - for i in range(max(1, hidden_states.shape[0] // self.slice_size)): - start_idx = i * self.slice_size - end_idx = (i + 1) * self.slice_size - - query_slice = query[start_idx:end_idx] - original_key_slice = original_text_key[start_idx:end_idx] - modified_key_slice = modified_text_key[start_idx:end_idx] - attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None - - original_attn_slice = attn.get_attention_scores(query_slice, original_key_slice, attn_mask_slice) - modified_attn_slice = attn.get_attention_scores(query_slice, modified_key_slice, attn_mask_slice) - - # because the prompt modifications may result in token sequences shifted forwards or backwards, - # the original attention probabilities must be remapped to account for token index changes in the - # modified prompt - remapped_original_attn_slice = torch.index_select( - original_attn_slice, -1, swap_cross_attn_context.index_map - ) - - # only some tokens taken from the original attention probabilities. this is controlled by the mask. - mask = swap_cross_attn_context.mask - inverse_mask = 1 - mask - attn_slice = remapped_original_attn_slice * mask + modified_attn_slice * inverse_mask - - del remapped_original_attn_slice, modified_attn_slice - - attn_slice = torch.bmm(attn_slice, modified_value[start_idx:end_idx]) - hidden_states[start_idx:end_idx] = attn_slice - - # done - hidden_states = attn.batch_to_head_dim(hidden_states) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - return hidden_states - - -class SwapCrossAttnProcessor(SlicedSwapCrossAttnProcesser): - def __init__(self): - super(SwapCrossAttnProcessor, self).__init__(slice_size=int(1e9)) # massive slice size = don't slice diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index 4d95cb8f0d..f565f58352 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -1,16 +1,13 @@ from __future__ import annotations import math -from contextlib import contextmanager from typing import Any, Callable, Optional, Union import torch -from diffusers import UNet2DConditionModel from typing_extensions import TypeAlias from invokeai.app.services.config.config_default import get_config from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( - ExtraConditioningInfo, IPAdapterConditioningInfo, Range, TextConditioningData, @@ -18,13 +15,6 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( ) from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData -from .cross_attention_control import ( - CrossAttentionType, - CrossAttnControlContext, - SwapCrossAttnContext, - setup_cross_attention_control_attention_processors, -) - ModelForwardCallback: TypeAlias = Union[ # x, t, conditioning, Optional[cross-attention kwargs] Callable[ @@ -61,31 +51,8 @@ class InvokeAIDiffuserComponent: self.conditioning = None self.model = model self.model_forward_callback = model_forward_callback - self.cross_attention_control_context = None self.sequential_guidance = config.sequential_guidance - @contextmanager - def custom_attention_context( - self, - unet: UNet2DConditionModel, - extra_conditioning_info: Optional[ExtraConditioningInfo], - ): - old_attn_processors = unet.attn_processors - - try: - self.cross_attention_control_context = CrossAttnControlContext( - arguments=extra_conditioning_info.cross_attention_control_args, - ) - setup_cross_attention_control_attention_processors( - unet, - self.cross_attention_control_context, - ) - - yield None - finally: - self.cross_attention_control_context = None - unet.set_attn_processor(old_attn_processors) - def do_controlnet_step( self, control_data, @@ -210,16 +177,8 @@ class InvokeAIDiffuserComponent: down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter ): percent_through = step_index / total_step_count - cross_attention_control_types_to_do = [] - if self.cross_attention_control_context is not None: - cross_attention_control_types_to_do = ( - self.cross_attention_control_context.get_active_cross_attention_control_types_for_step(percent_through) - ) - wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0 - if wants_cross_attention_control or self.sequential_guidance: - # If wants_cross_attention_control is True, we force the sequential mode to be used, because cross-attention - # control is currently only supported in sequential mode. + if self.sequential_guidance: ( unconditioned_next_x, conditioned_next_x, @@ -229,7 +188,6 @@ class InvokeAIDiffuserComponent: conditioning_data=conditioning_data, ip_adapter_conditioning=ip_adapter_conditioning, percent_through=percent_through, - cross_attention_control_types_to_do=cross_attention_control_types_to_do, down_block_additional_residuals=down_block_additional_residuals, mid_block_additional_residual=mid_block_additional_residual, down_intrablock_additional_residuals=down_intrablock_additional_residuals, @@ -394,7 +352,6 @@ class InvokeAIDiffuserComponent: conditioning_data: TextConditioningData, ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]], percent_through: float, - cross_attention_control_types_to_do: list[CrossAttentionType], down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter @@ -424,19 +381,6 @@ class InvokeAIDiffuserComponent: if mid_block_additional_residual is not None: uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2) - # If cross-attention control is enabled, prepare the SwapCrossAttnContext. - cross_attn_processor_context = None - if self.cross_attention_control_context is not None: - # Note that the SwapCrossAttnContext is initialized with an empty list of cross_attention_types_to_do. - # This list is empty because cross-attention control is not applied in the unconditioned pass. This field - # will be populated before the conditioned pass. - cross_attn_processor_context = SwapCrossAttnContext( - modified_text_embeddings=self.cross_attention_control_context.arguments.edited_conditioning, - index_map=self.cross_attention_control_context.cross_attention_index_map, - mask=self.cross_attention_control_context.cross_attention_mask, - cross_attention_types_to_do=[], - ) - ##################### # Unconditioned pass ##################### @@ -451,10 +395,6 @@ class InvokeAIDiffuserComponent: for ipa_conditioning in ip_adapter_conditioning ] - # Prepare cross-attention control kwargs for the unconditioned pass. - if cross_attn_processor_context is not None: - cross_attention_kwargs["swap_cross_attn_context"] = cross_attn_processor_context - # Prepare SDXL conditioning kwargs for the unconditioned pass. added_cond_kwargs = None if conditioning_data.is_sdxl(): @@ -496,11 +436,6 @@ class InvokeAIDiffuserComponent: for ipa_conditioning in ip_adapter_conditioning ] - # Prepare cross-attention control kwargs for the conditioned pass. - if cross_attn_processor_context is not None: - cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do - cross_attention_kwargs["swap_cross_attn_context"] = cross_attn_processor_context - # Prepare SDXL conditioning kwargs for the conditioned pass. added_cond_kwargs = None if conditioning_data.is_sdxl(): diff --git a/invokeai/invocation_api/__init__.py b/invokeai/invocation_api/__init__.py index 11f334e24e..4eb78cf1ee 100644 --- a/invokeai/invocation_api/__init__.py +++ b/invokeai/invocation_api/__init__.py @@ -76,7 +76,6 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineInterme from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( BasicConditioningInfo, ConditioningFieldData, - ExtraConditioningInfo, SDXLConditioningInfo, ) from invokeai.backend.util.devices import CPU_DEVICE, CUDA_DEVICE, MPS_DEVICE, choose_precision, choose_torch_device @@ -151,7 +150,6 @@ __all__ = [ # invokeai.backend.stable_diffusion.diffusion.conditioning_data "BasicConditioningInfo", "ConditioningFieldData", - "ExtraConditioningInfo", "SDXLConditioningInfo", # invokeai.backend.stable_diffusion.diffusers_pipeline "PipelineIntermediateState", diff --git a/pyproject.toml b/pyproject.toml index bca13c5109..5f924808a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -256,7 +256,6 @@ module = [ "invokeai.backend.model_management.seamless", "invokeai.backend.model_management.util", "invokeai.backend.stable_diffusion.diffusers_pipeline", - "invokeai.backend.stable_diffusion.diffusion.cross_attention_control", "invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion", "invokeai.backend.util.hotfixes", "invokeai.backend.util.mps_fixes",