diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py
index 81b92d4fa7..6005bc83e0 100644
--- a/invokeai/app/invocations/denoise_latents.py
+++ b/invokeai/app/invocations/denoise_latents.py
@@ -776,6 +776,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
                     seed=seed,
                     scheduler_step_kwargs=scheduler_step_kwargs,
                     conditioning_data=conditioning_data,
+                    attention_processor_cls=CustomAttnProcessor2_0,
                 ),
                 unet=None,
                 scheduler=scheduler,
@@ -797,8 +798,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
             assert isinstance(unet_info.model, UNet2DConditionModel)
             with (
                 unet_info.model_on_device() as (model_state_dict, unet),
+                ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
                 # ext: controlnet
-                ext_manager.patch_attention_processor(unet, CustomAttnProcessor2_0),
+                ext_manager.patch_extensions(denoise_ctx),
                 # ext: freeu, seamless, ip adapter, lora
                 ext_manager.patch_unet(model_state_dict, unet),
             ):
diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py
index 8c7a62c371..d31cb6bdef 100644
--- a/invokeai/backend/model_patcher.py
+++ b/invokeai/backend/model_patcher.py
@@ -5,7 +5,7 @@ from __future__ import annotations

 import pickle
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union

 import numpy as np
 import torch
@@ -32,8 +32,27 @@ with LoRAHelper.apply_lora_unet(unet, loras):
 """


-# TODO: rename smth like ModelPatcher and add TI method?
 class ModelPatcher:
+    @staticmethod
+    @contextmanager
+    def patch_unet_attention_processor(unet: UNet2DConditionModel, processor_cls: Type[Any]):
+        """A context manager that patches `unet` with the provided attention processor.
+
+        Args:
+            unet (UNet2DConditionModel): The UNet model to patch.
+            processor_cls (Type[Any]): Processor class; a separate instance is created for each attention key and passed to set_attn_processor(...).
+        """
+        unet_orig_processors = unet.attn_processors
+        try:
+            # Create a separate processor instance for each attention layer so that each can be modified independently.
+            new_attn_processors = {key: processor_cls() for key in unet_orig_processors.keys()}
+            unet.set_attn_processor(new_attn_processors)
+
+            yield None
+
+        finally:
+            unet.set_attn_processor(unet_orig_processors)
+
     @staticmethod
     def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
         assert "." not in lora_key
diff --git a/invokeai/backend/stable_diffusion/denoise_context.py b/invokeai/backend/stable_diffusion/denoise_context.py
index 2a00052fd1..26c3b02c3b 100644
--- a/invokeai/backend/stable_diffusion/denoise_context.py
+++ b/invokeai/backend/stable_diffusion/denoise_context.py
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union

 import torch
 from diffusers import UNet2DConditionModel
@@ -38,6 +38,7 @@ class DenoiseInputs:
     seed: int
     timesteps: torch.Tensor
     init_timestep: torch.Tensor
+    attention_processor_cls: Type[Any]


 @dataclass
diff --git a/invokeai/backend/stable_diffusion/extensions/base.py b/invokeai/backend/stable_diffusion/extensions/base.py
index 3effa77da4..2aaf49e3b9 100644
--- a/invokeai/backend/stable_diffusion/extensions/base.py
+++ b/invokeai/backend/stable_diffusion/extensions/base.py
@@ -1,10 +1,15 @@
+from __future__ import annotations
+
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional

 import torch
 from diffusers import UNet2DConditionModel

+if TYPE_CHECKING:
+    from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
+

 @dataclass
 class InjectionInfo:
@@ -37,7 +42,7 @@ class ExtensionBase:
             self.injections.append(InjectionInfo(**func.__inj_info__, function=func))

     @contextmanager
-    def patch_attention_processor(self, attention_processor_cls: object):
+    def patch_extension(self, context: DenoiseContext):
         yield None

     @contextmanager
diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py
index 876fd96d39..e747579d8b 100644
--- a/invokeai/backend/stable_diffusion/extensions_manager.py
+++ b/invokeai/backend/stable_diffusion/extensions_manager.py
@@ -98,39 +98,14 @@ class ExtensionsManager:
         if name in self._callbacks:
             self._callbacks[name](*args, **kwargs)

-    # TODO: is there any need in such high abstarction
-    # @contextmanager
-    # def patch_extensions(self):
-    #     exit_stack = ExitStack()
-    #     try:
-    #         for ext in self.extensions:
-    #             exit_stack.enter_context(ext.patch_extension(self))
-    #
-    #         yield None
-    #
-    #     finally:
-    #         exit_stack.close()
-
     @contextmanager
-    def patch_attention_processor(self, unet: UNet2DConditionModel, attn_processor_cls: object):
-        unet_orig_processors = unet.attn_processors
-        exit_stack = ExitStack()
-        try:
-            # just to be sure that attentions have not same processor instance
-            attn_procs = {}
-            for name in unet.attn_processors.keys():
-                attn_procs[name] = attn_processor_cls()
-            unet.set_attn_processor(attn_procs)
-
+    def patch_extensions(self, context: DenoiseContext):
+        with ExitStack() as exit_stack:
             for ext in self.extensions:
-                exit_stack.enter_context(ext.patch_attention_processor(attn_processor_cls))
+                exit_stack.enter_context(ext.patch_extension(context))

             yield None

-        finally:
-            unet.set_attn_processor(unet_orig_processors)
-            exit_stack.close()
-
     @contextmanager
     def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
         exit_stack = ExitStack()
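
Note for reviewers: below is a minimal, self-contained sketch of the swap-and-restore pattern that `ModelPatcher.patch_unet_attention_processor` introduces above. It substitutes diffusers' stock `AttnProcessor2_0` for InvokeAI's `CustomAttnProcessor2_0` and uses a default-config `UNet2DConditionModel` with random weights; both substitutions are illustrative assumptions, not part of this diff.

from contextlib import contextmanager
from typing import Any, Type

from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0


@contextmanager
def patch_unet_attention_processor(unet: UNet2DConditionModel, processor_cls: Type[Any]):
    """Temporarily replace every attention processor on `unet`, restoring the originals on exit."""
    unet_orig_processors = unet.attn_processors
    try:
        # One instance per attention key, so each layer's processor can carry its own state.
        unet.set_attn_processor({key: processor_cls() for key in unet_orig_processors})
        yield
    finally:
        unet.set_attn_processor(unet_orig_processors)


unet = UNet2DConditionModel()  # default config, random weights -- illustration only
with patch_unet_attention_processor(unet, AttnProcessor2_0):
    # Inside the block, every attention layer runs the patched processor class.
    assert all(isinstance(p, AttnProcessor2_0) for p in unet.attn_processors.values())
# On exit the original processors are restored, even if the body raised.

Instantiating one processor per attention key, rather than sharing a single instance, is what lets each layer's processor be modified independently, which the per-extension `patch_extension(context)` hooks introduced in this diff can then take advantage of.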