diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index 94ec9da7e8..10e0ad4da3 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -545,8 +545,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         additional_guidance = []
         extra_conditioning_info = conditioning_data.extra
         with self.invokeai_diffuser.custom_attention_context(
-            extra_conditioning_info=extra_conditioning_info,
-            step_count=len(self.scheduler.timesteps),
+            self.invokeai_diffuser.model,
+            extra_conditioning_info=extra_conditioning_info,
+            step_count=len(self.scheduler.timesteps),
         ):
             yield PipelineIntermediateState(
                 run_id=run_id,
diff --git a/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py b/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py
index dfd19ea964..79a0982cfe 100644
--- a/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py
+++ b/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py
@@ -10,6 +10,7 @@ import diffusers
 import psutil
 import torch
 from compel.cross_attention_control import Arguments
+from diffusers.models.unet_2d_condition import UNet2DConditionModel
 from diffusers.models.attention_processor import AttentionProcessor
 from torch import nn
 
@@ -352,8 +353,7 @@ def restore_default_cross_attention(
     else:
         remove_attention_function(model)
 
-
-def override_cross_attention(model, context: Context, is_running_diffusers=False):
+def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
     """
     Inject attention parameters and functions into the passed in model
     to enable cross attention editing.
@@ -372,37 +372,22 @@ def override_cross_attention(model, context: Context, is_running_diffusers=False):
     indices = torch.arange(max_length, dtype=torch.long)
     for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
         if b0 < max_length:
-            if name == "equal": # or (name == "replace" and a1 - a0 == b1 - b0):
+            if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
                 # these tokens have not been edited
                 indices[b0:b1] = indices_target[a0:a1]
                 mask[b0:b1] = 1
 
     context.cross_attention_mask = mask.to(device)
     context.cross_attention_index_map = indices.to(device)
-    if is_running_diffusers:
-        unet = model
-        old_attn_processors = unet.attn_processors
-        if torch.backends.mps.is_available():
-            # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
-            unet.set_attn_processor(SwapCrossAttnProcessor())
-        else:
-            # try to re-use an existing slice size
-            default_slice_size = 4
-            slice_size = next(
-                (
-                    p.slice_size
-                    for p in old_attn_processors.values()
-                    if type(p) is SlicedAttnProcessor
-                ),
-                default_slice_size,
-            )
-            unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
-        return old_attn_processors
+    old_attn_processors = unet.attn_processors
+    if torch.backends.mps.is_available():
+        # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
+        unet.set_attn_processor(SwapCrossAttnProcessor())
     else:
-        context.register_cross_attention_modules(model)
-        inject_attention_function(model, context)
-        return None
-
+        # try to re-use an existing slice size
+        default_slice_size = 4
+        slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
+        unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
 
 def get_cross_attention_modules(
     model, which: CrossAttentionType
diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
index b0c85e9fd3..245317bcde 100644
--- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
+++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
@@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, Optional, Union
 
 import numpy as np
 import torch
+from diffusers import UNet2DConditionModel
 from diffusers.models.attention_processor import AttentionProcessor
 from typing_extensions import TypeAlias
 
@@ -17,8 +18,8 @@ from .cross_attention_control import (
     CrossAttentionType,
     SwapCrossAttnContext,
     get_cross_attention_modules,
-    override_cross_attention,
     restore_default_cross_attention,
+    setup_cross_attention_control_attention_processors,
 )
 from .cross_attention_map_saving import AttentionMapSaver
 
@@ -79,24 +80,35 @@ class InvokeAIDiffuserComponent:
         self.cross_attention_control_context = None
         self.sequential_guidance = Globals.sequential_guidance
 
+    @classmethod
     @contextmanager
     def custom_attention_context(
-        self, extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int
+        cls,
+        unet: UNet2DConditionModel,  # note: also may futz with the text encoder depending on requested LoRAs
+        extra_conditioning_info: Optional[ExtraConditioningInfo],
+        step_count: int
     ):
-        do_swap = (
-            extra_conditioning_info is not None
-            and extra_conditioning_info.wants_cross_attention_control
-        )
-        old_attn_processor = None
-        if do_swap:
-            old_attn_processor = self.override_cross_attention(
-                extra_conditioning_info, step_count=step_count
-            )
+        old_attn_processors = None
+        if extra_conditioning_info and (
+            extra_conditioning_info.wants_cross_attention_control
+        ):
+            old_attn_processors = unet.attn_processors
+            # Load lora conditions into the model
+            if extra_conditioning_info.wants_cross_attention_control:
+                cross_attention_control_context = Context(
+                    arguments=extra_conditioning_info.cross_attention_control_args,
+                    step_count=step_count,
+                )
+                setup_cross_attention_control_attention_processors(
+                    unet,
+                    cross_attention_control_context,
+                )
+
         try:
             yield None
         finally:
-            if old_attn_processor is not None:
-                self.restore_default_cross_attention(old_attn_processor)
+            if old_attn_processors is not None:
+                unet.set_attn_processor(old_attn_processors)
             # TODO resuscitate attention map saving
             # self.remove_attention_map_saving()
 
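For reference, a minimal usage sketch of the reworked classmethod context manager, mirroring the call site in the diffusers_pipeline.py hunk above. This is not part of the change itself: the names unet, conditioning_data, scheduler, and run_denoising_loop are assumed stand-ins for objects the calling pipeline already holds.

    from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import (
        InvokeAIDiffuserComponent,
    )

    # Sketch only: unet, conditioning_data, and scheduler are assumed to exist in the caller,
    # analogous to self.invokeai_diffuser.model, conditioning_data, and self.scheduler above.
    with InvokeAIDiffuserComponent.custom_attention_context(
        unet,  # attention processors are swapped on this UNet and restored when the block exits
        extra_conditioning_info=conditioning_data.extra,
        step_count=len(scheduler.timesteps),
    ):
        run_denoising_loop()  # hypothetical placeholder: denoising runs here with cross-attention control active

Because custom_attention_context no longer reads state from an InvokeAIDiffuserComponent instance, the UNet whose attention processors get swapped is passed in explicitly, and the original processors captured in old_attn_processors are restored in the finally block.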