Change attention processor apply logic

This commit is contained in:
Sergey Borisov 2024-07-16 20:03:29 +03:00
parent 608cbe3f5c
commit cec345cb5c
5 changed files with 36 additions and 34 deletions

View File

@ -776,6 +776,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
seed=seed, seed=seed,
scheduler_step_kwargs=scheduler_step_kwargs, scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
attention_processor_cls=CustomAttnProcessor2_0,
), ),
unet=None, unet=None,
scheduler=scheduler, scheduler=scheduler,
@ -797,8 +798,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
assert isinstance(unet_info.model, UNet2DConditionModel) assert isinstance(unet_info.model, UNet2DConditionModel)
with ( with (
unet_info.model_on_device() as (model_state_dict, unet), unet_info.model_on_device() as (model_state_dict, unet),
ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
# ext: controlnet # ext: controlnet
ext_manager.patch_attention_processor(unet, CustomAttnProcessor2_0), ext_manager.patch_extensions(unet),
# ext: freeu, seamless, ip adapter, lora # ext: freeu, seamless, ip adapter, lora
ext_manager.patch_unet(model_state_dict, unet), ext_manager.patch_unet(model_state_dict, unet),
): ):

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import pickle import pickle
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union
import numpy as np import numpy as np
import torch import torch
@ -32,8 +32,27 @@ with LoRAHelper.apply_lora_unet(unet, loras):
""" """
# TODO: rename to something like ModelPatcher and add a TI method?
class ModelPatcher: class ModelPatcher:
@staticmethod
@contextmanager
def patch_unet_attention_processor(unet: UNet2DConditionModel, processor_cls: Type[Any]) -> Iterator[None]:
    """A context manager that patches `unet` with the provided attention processor.

    The processors installed on `unet` at entry are captured first and are put
    back on exit, including when the body of the `with` block raises.

    Args:
        unet (UNet2DConditionModel): The UNet model to patch.
        processor_cls (Type[Any]): Class which is instantiated (with no arguments) once for
            each key of `unet.attn_processors`; the resulting mapping is passed to
            `unet.set_attn_processor(...)`.

    Yields:
        None.
    """
    unet_orig_processors = unet.attn_processors
    try:
        # Create a separate instance for each attention key (never a shared
        # one), so that each attention can be modified independently.
        new_attn_processors = {key: processor_cls() for key in unet_orig_processors}
        unet.set_attn_processor(new_attn_processors)
        yield None
    finally:
        unet.set_attn_processor(unet_orig_processors)
@staticmethod @staticmethod
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]: def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
assert "." not in lora_key assert "." not in lora_key

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union
import torch import torch
from diffusers import UNet2DConditionModel from diffusers import UNet2DConditionModel
@ -38,6 +38,7 @@ class DenoiseInputs:
seed: int seed: int
timesteps: torch.Tensor timesteps: torch.Tensor
init_timestep: torch.Tensor init_timestep: torch.Tensor
attention_processor_cls: Type[Any]
@dataclass @dataclass

View File

@ -1,10 +1,15 @@
from __future__ import annotations
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import Callable, Dict, List, Optional from typing import TYPE_CHECKING, Callable, Dict, List, Optional
import torch import torch
from diffusers import UNet2DConditionModel from diffusers import UNet2DConditionModel
if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
@dataclass @dataclass
class InjectionInfo: class InjectionInfo:
@ -37,7 +42,7 @@ class ExtensionBase:
self.injections.append(InjectionInfo(**func.__inj_info__, function=func)) self.injections.append(InjectionInfo(**func.__inj_info__, function=func))
@contextmanager @contextmanager
def patch_attention_processor(self, attention_processor_cls: object): def patch_extension(self, context: DenoiseContext):
yield None yield None
@contextmanager @contextmanager

View File

@ -98,39 +98,14 @@ class ExtensionsManager:
if name in self._callbacks: if name in self._callbacks:
self._callbacks[name](*args, **kwargs) self._callbacks[name](*args, **kwargs)
# TODO: is there any need for such a high level of abstraction?
# @contextmanager
# def patch_extensions(self):
# exit_stack = ExitStack()
# try:
# for ext in self.extensions:
# exit_stack.enter_context(ext.patch_extension(self))
#
# yield None
#
# finally:
# exit_stack.close()
@contextmanager @contextmanager
def patch_attention_processor(self, unet: UNet2DConditionModel, attn_processor_cls: object): def patch_extensions(self, context: DenoiseContext):
unet_orig_processors = unet.attn_processors with ExitStack() as exit_stack:
exit_stack = ExitStack()
try:
# just to be sure that attentions have not same processor instance
attn_procs = {}
for name in unet.attn_processors.keys():
attn_procs[name] = attn_processor_cls()
unet.set_attn_processor(attn_procs)
for ext in self.extensions: for ext in self.extensions:
exit_stack.enter_context(ext.patch_attention_processor(attn_processor_cls)) exit_stack.enter_context(ext.patch_extension(context))
yield None yield None
finally:
unet.set_attn_processor(unet_orig_processors)
exit_stack.close()
@contextmanager @contextmanager
def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel): def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
exit_stack = ExitStack() exit_stack = ExitStack()