invoke-ai/InvokeAI (https://github.com/invoke-ai/InvokeAI)

Commit: cec345cb5c
Parent: 608cbe3f5c

Change attention processor apply logic
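
In short: applying the attention processor to the UNet moves out of ExtensionsManager and into a new ModelPatcher.patch_unet_attention_processor context manager; the processor class now travels with the denoise inputs as DenoiseInputs.attention_processor_cls; and the extension framework trades its attention-specific hook (patch_attention_processor) for a generic patch_extension(context) hook that ExtensionsManager.patch_extensions enters for every registered extension.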
@@ -776,6 +776,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 seed=seed,
                 scheduler_step_kwargs=scheduler_step_kwargs,
                 conditioning_data=conditioning_data,
+                attention_processor_cls=CustomAttnProcessor2_0,
             ),
             unet=None,
             scheduler=scheduler,
@@ -797,8 +798,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
         assert isinstance(unet_info.model, UNet2DConditionModel)
         with (
             unet_info.model_on_device() as (model_state_dict, unet),
+            ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
             # ext: controlnet
-            ext_manager.patch_attention_processor(unet, CustomAttnProcessor2_0),
+            ext_manager.patch_extensions(unet),
             # ext: freeu, seamless, ip adapter, lora
             ext_manager.patch_unet(model_state_dict, unet),
         ):

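For readers unfamiliar with the parenthesized `with (...)` form: the managers are entered top to bottom and exited in reverse, so the original attention processors are restored only after the extension and UNet patches have unwound. A minimal sketch of the equivalent nesting, with the denoising body elided:

with unet_info.model_on_device() as (model_state_dict, unet):
    with ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls):
        with ext_manager.patch_extensions(unet):  # ext: controlnet
            with ext_manager.patch_unet(model_state_dict, unet):  # ext: freeu, seamless, ip adapter, lora
                ...  # denoising loop runs here
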
@@ -5,7 +5,7 @@ from __future__ import annotations

 import pickle
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union

 import numpy as np
 import torch
@@ -32,8 +32,27 @@ with LoRAHelper.apply_lora_unet(unet, loras):
 """


-# TODO: rename smth like ModelPatcher and add TI method?
 class ModelPatcher:
+    @staticmethod
+    @contextmanager
+    def patch_unet_attention_processor(unet: UNet2DConditionModel, processor_cls: Type[Any]):
+        """A context manager that patches `unet` with the provided attention processor.
+
+        Args:
+            unet (UNet2DConditionModel): The UNet model to patch.
+            processor (Type[Any]): Class which will be initialized for each key and passed to set_attn_processor(...).
+        """
+        unet_orig_processors = unet.attn_processors
+        try:
+            # create separate instance for each attention, to be able modify each attention separately
+            new_attn_processors = {key: processor_cls() for key in unet_orig_processors.keys()}
+            unet.set_attn_processor(new_attn_processors)
+
+            yield None
+
+        finally:
+            unet.set_attn_processor(unet_orig_processors)
+
     @staticmethod
     def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
         assert "." not in lora_key

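A usage sketch for the new context manager (hypothetical caller; assume `unet`, `latents`, `timestep`, and `embeddings` are already prepared, and `CustomAttnProcessor2_0` is the class the invocation passes in the first hunk):

with ModelPatcher.patch_unet_attention_processor(unet, CustomAttnProcessor2_0):
    # Each attention layer now holds its own processor instance, so each
    # attention can be modified independently (per the comment in the diff).
    noise_pred = unet(latents, timestep, encoder_hidden_states=embeddings).sample
# On exit, set_attn_processor(...) restores the original processors even if the
# body raised, thanks to the try/finally inside the context manager.
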
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union

 import torch
 from diffusers import UNet2DConditionModel
@@ -38,6 +38,7 @@ class DenoiseInputs:
     seed: int
     timesteps: torch.Tensor
     init_timestep: torch.Tensor
+    attention_processor_cls: Type[Any]


 @dataclass
@@ -1,10 +1,15 @@
+from __future__ import annotations
+
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional

 import torch
 from diffusers import UNet2DConditionModel

+if TYPE_CHECKING:
+    from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
+

 @dataclass
 class InjectionInfo:

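Note that the DenoiseContext import sits under `if TYPE_CHECKING:` and the file gains `from __future__ import annotations`: the name is needed only for the new patch_extension(context: DenoiseContext) annotation below, and deferring annotation evaluation avoids importing the denoise-context module at runtime (and with it a potential circular import).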
@@ -37,7 +42,7 @@ class ExtensionBase:
         self.injections.append(InjectionInfo(**func.__inj_info__, function=func))

     @contextmanager
-    def patch_attention_processor(self, attention_processor_cls: object):
+    def patch_extension(self, context: DenoiseContext):
         yield None

     @contextmanager

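The base implementation above is a no-op that simply yields. A hypothetical extension (not part of this commit; it builds on ExtensionBase and the imports from the file above) would override it to wrap the whole denoise run, for example:

class MyExtension(ExtensionBase):  # hypothetical example extension
    @contextmanager
    def patch_extension(self, context: DenoiseContext):
        # set up state before denoising starts
        try:
            yield None
        finally:
            pass  # tear down / restore state afterwards
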
@@ -98,39 +98,14 @@ class ExtensionsManager:
             if name in self._callbacks:
                 self._callbacks[name](*args, **kwargs)

-    # TODO: is there any need in such high abstarction
-    # @contextmanager
-    # def patch_extensions(self):
-    #     exit_stack = ExitStack()
-    #     try:
-    #         for ext in self.extensions:
-    #             exit_stack.enter_context(ext.patch_extension(self))
-    #
-    #         yield None
-    #
-    #     finally:
-    #         exit_stack.close()
-
     @contextmanager
-    def patch_attention_processor(self, unet: UNet2DConditionModel, attn_processor_cls: object):
-        unet_orig_processors = unet.attn_processors
-        exit_stack = ExitStack()
-        try:
-            # just to be sure that attentions have not same processor instance
-            attn_procs = {}
-            for name in unet.attn_processors.keys():
-                attn_procs[name] = attn_processor_cls()
-            unet.set_attn_processor(attn_procs)
-
+    def patch_extensions(self, context: DenoiseContext):
+        with ExitStack() as exit_stack:
             for ext in self.extensions:
-                exit_stack.enter_context(ext.patch_attention_processor(attn_processor_cls))
+                exit_stack.enter_context(ext.patch_extension(context))

             yield None
-
-        finally:
-            unet.set_attn_processor(unet_orig_processors)
-            exit_stack.close()

     @contextmanager
     def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
         exit_stack = ExitStack()
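
The rewrite of patch_extensions is behavior-preserving with respect to cleanup: `with ExitStack() as exit_stack:` closes the stack on exit exactly as the removed try/finally with `exit_stack.close()` did, unwinding each extension's patch_extension context in reverse registration order even when the body raises. What changed is the responsibility split: the manager no longer touches `set_attn_processor` at all; that now lives solely in ModelPatcher.patch_unet_attention_processor.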