Change attention processor apply logic

Sergey Borisov 2024-07-16 20:03:29 +03:00
parent 608cbe3f5c
commit cec345cb5c
5 changed files with 36 additions and 34 deletions

View File

@@ -776,6 +776,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
seed=seed,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
attention_processor_cls=CustomAttnProcessor2_0,
),
unet=None,
scheduler=scheduler,
@@ -797,8 +798,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
unet_info.model_on_device() as (model_state_dict, unet),
ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
# ext: controlnet
ext_manager.patch_attention_processor(unet, CustomAttnProcessor2_0),
ext_manager.patch_extensions(unet),
# ext: freeu, seamless, ip adapter, lora
ext_manager.patch_unet(model_state_dict, unet),
):
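Note: the with-block above now layers three reversible patches around the denoise run, inside the model_on_device context: the UNet attention processors (ModelPatcher.patch_unet_attention_processor), the per-extension patches (ext_manager.patch_extensions), and the UNet weight patches (ext_manager.patch_unet). Below is a minimal, self-contained sketch of that layering; the _patch helper and its labels are hypothetical stand-ins for the real context managers, used only to show that everything is reverted in reverse order when the block exits.

from contextlib import contextmanager


@contextmanager
def _patch(label: str):
    # stand-in for one of the real patch context managers named above
    print(f"apply {label}")
    try:
        yield None
    finally:
        print(f"revert {label}")


with (
    _patch("unet attention processors"),        # ModelPatcher.patch_unet_attention_processor
    _patch("extensions"),                       # ext_manager.patch_extensions
    _patch("freeu/seamless/ip adapter/lora"),   # ext_manager.patch_unet
):
    print("run denoise loop")
# on exit the patches are reverted in reverse order: weight patches first, attention processors last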

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
import pickle
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union
import numpy as np
import torch
@@ -32,8 +32,27 @@ with LoRAHelper.apply_lora_unet(unet, loras):
"""
# TODO: rename to something like ModelPatcher and add TI method?
class ModelPatcher:
@staticmethod
@contextmanager
def patch_unet_attention_processor(unet: UNet2DConditionModel, processor_cls: Type[Any]):
"""A context manager that patches `unet` with the provided attention processor.
Args:
unet (UNet2DConditionModel): The UNet model to patch.
processor_cls (Type[Any]): Class that will be instantiated for each attention key and passed to set_attn_processor(...).
"""
unet_orig_processors = unet.attn_processors
try:
# create a separate instance for each attention layer, so each can be modified independently
new_attn_processors = {key: processor_cls() for key in unet_orig_processors.keys()}
unet.set_attn_processor(new_attn_processors)
yield None
finally:
unet.set_attn_processor(unet_orig_processors)
@staticmethod
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
assert "." not in lora_key

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union
import torch
from diffusers import UNet2DConditionModel
@@ -38,6 +38,7 @@ class DenoiseInputs:
seed: int
timesteps: torch.Tensor
init_timestep: torch.Tensor
attention_processor_cls: Type[Any]
@dataclass
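Note: the new field carries a class, not an instance, which is why it is typed Type[Any]: the patching code later builds one instance per attention key. A stripped-down, hypothetical stand-in for DenoiseInputs (every name below is illustrative, not InvokeAI's real API) shows the idea.

from dataclasses import dataclass
from typing import Any, Dict, Type


class _DummyProcessor:  # stand-in for CustomAttnProcessor2_0
    pass


@dataclass
class _InputsSketch:  # hypothetical, heavily trimmed stand-in for DenoiseInputs
    seed: int
    attention_processor_cls: Type[Any]


inputs = _InputsSketch(seed=0, attention_processor_cls=_DummyProcessor)
# downstream code builds one processor instance per attention key from the stored class
per_key: Dict[str, Any] = {key: inputs.attention_processor_cls() for key in ("down.0", "mid", "up.0")}
assert all(isinstance(proc, _DummyProcessor) for proc in per_key.values())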

View File

@@ -1,10 +1,15 @@
from __future__ import annotations
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
import torch
from diffusers import UNet2DConditionModel
if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
@dataclass
class InjectionInfo:
@@ -37,7 +42,7 @@ class ExtensionBase:
self.injections.append(InjectionInfo(**func.__inj_info__, function=func))
@contextmanager
def patch_attention_processor(self, attention_processor_cls: object):
def patch_extension(self, context: DenoiseContext):
yield None
@contextmanager
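Note: extensions now implement a single patch_extension(context) hook instead of the attention-processor-specific one. Below is a hedged sketch of the new hook's shape only; _FakeContext and _DemoExtension are stand-ins for DenoiseContext and an ExtensionBase subclass (neither is imported here), and the scratch-dict patching is purely illustrative.

from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class _FakeContext:  # stand-in for DenoiseContext
    scratch: Dict[str, Any] = field(default_factory=dict)


class _DemoExtension:  # stand-in for an ExtensionBase subclass
    @contextmanager
    def patch_extension(self, context: _FakeContext):
        # apply whatever this extension needs for the duration of the denoise run
        context.scratch["demo_extension_active"] = True
        try:
            yield None
        finally:
            # always undone, even if denoising raises
            context.scratch.pop("demo_extension_active", None)


ctx = _FakeContext()
with _DemoExtension().patch_extension(ctx):
    assert ctx.scratch["demo_extension_active"]
assert "demo_extension_active" not in ctx.scratch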

View File

@@ -98,39 +98,14 @@ class ExtensionsManager:
if name in self._callbacks:
self._callbacks[name](*args, **kwargs)
# TODO: is there any need for such a high abstraction?
# @contextmanager
# def patch_extensions(self):
# exit_stack = ExitStack()
# try:
# for ext in self.extensions:
# exit_stack.enter_context(ext.patch_extension(self))
#
# yield None
#
# finally:
# exit_stack.close()
@contextmanager
def patch_attention_processor(self, unet: UNet2DConditionModel, attn_processor_cls: object):
unet_orig_processors = unet.attn_processors
exit_stack = ExitStack()
try:
# make sure attention layers do not share the same processor instance
attn_procs = {}
for name in unet.attn_processors.keys():
attn_procs[name] = attn_processor_cls()
unet.set_attn_processor(attn_procs)
def patch_extensions(self, context: DenoiseContext):
with ExitStack() as exit_stack:
for ext in self.extensions:
exit_stack.enter_context(ext.patch_attention_processor(attn_processor_cls))
exit_stack.enter_context(ext.patch_extension(context))
yield None
finally:
unet.set_attn_processor(unet_orig_processors)
exit_stack.close()
@contextmanager
def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
exit_stack = ExitStack()
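Note: the rewritten ExtensionsManager.patch_extensions simply enters every extension's patch_extension via an ExitStack, so all patches are unwound in reverse order when the block exits, including on error. Below is a self-contained sketch of that pattern; _Ext and the events list are illustrative stand-ins, not InvokeAI code.

from contextlib import ExitStack, contextmanager
from typing import Any, List

events: List[str] = []


class _Ext:  # stand-in for an ExtensionBase subclass
    def __init__(self, name: str):
        self.name = name

    @contextmanager
    def patch_extension(self, context: Any):
        events.append(f"enter {self.name}")
        try:
            yield None
        finally:
            events.append(f"exit {self.name}")


@contextmanager
def patch_extensions(extensions: List[_Ext], context: Any):
    # same shape as the new ExtensionsManager.patch_extensions
    with ExitStack() as exit_stack:
        for ext in extensions:
            exit_stack.enter_context(ext.patch_extension(context))

        yield None


with patch_extensions([_Ext("a"), _Ext("b")], context=None):
    events.append("denoise")

print(events)  # ['enter a', 'enter b', 'denoise', 'exit b', 'exit a']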