mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Base code from draft PR
This commit is contained in:
195
invokeai/backend/stable_diffusion/extensions_manager.py
Normal file
195
invokeai/backend/stable_diffusion/extensions_manager.py
Normal file
@ -0,0 +1,195 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import ExitStack, contextmanager
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Callable, Dict
|
||||
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||
from invokeai.backend.stable_diffusion.extensions import ExtensionBase
|
||||
|
||||
|
||||
class ExtModifiersApi(ABC):
|
||||
@abstractmethod
|
||||
def pre_denoise_loop(self, ctx: DenoiseContext):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def post_denoise_loop(self, ctx: DenoiseContext):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pre_step(self, ctx: DenoiseContext):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def post_step(self, ctx: DenoiseContext):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def modify_noise_prediction(self, ctx: DenoiseContext):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pre_unet_forward(self, ctx: DenoiseContext):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pre_unet_load(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
||||
pass
|
||||
|
||||
|
||||
class ExtOverridesApi(ABC):
|
||||
@abstractmethod
|
||||
def step(self, orig_func: Callable, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def combine_noise(self, orig_func: Callable, ctx: DenoiseContext):
|
||||
pass
|
||||
|
||||
|
||||
class ProxyCallsClass:
|
||||
def __init__(self, handler):
|
||||
self._handler = handler
|
||||
|
||||
def __getattr__(self, item):
|
||||
return partial(self._handler, item)
|
||||
|
||||
|
||||
class ModifierInjectionPoint:
|
||||
def __init__(self):
|
||||
self.first = []
|
||||
self.any = []
|
||||
self.last = []
|
||||
|
||||
def add(self, func: Callable, order: str):
|
||||
if order == "first":
|
||||
self.first.append(func)
|
||||
elif order == "last":
|
||||
self.last.append(func)
|
||||
else: # elif order == "any":
|
||||
self.any.append(func)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
for func in self.first:
|
||||
func(*args, **kwargs)
|
||||
for func in self.any:
|
||||
func(*args, **kwargs)
|
||||
for func in reversed(self.last):
|
||||
func(*args, **kwargs)
|
||||
|
||||
|
||||
class ExtensionsManager:
|
||||
def __init__(self):
|
||||
self.extensions = []
|
||||
|
||||
self._overrides = {}
|
||||
self._modifiers = {}
|
||||
|
||||
self.modifiers: ExtModifiersApi = ProxyCallsClass(self.call_modifier)
|
||||
self.overrides: ExtOverridesApi = ProxyCallsClass(self.call_override)
|
||||
|
||||
def add_extension(self, ext: ExtensionBase):
|
||||
self.extensions.append(ext)
|
||||
ordered_extensions = sorted(self.extensions, reverse=True, key=lambda ext: ext.priority)
|
||||
|
||||
self._overrides.clear()
|
||||
self._modifiers.clear()
|
||||
|
||||
for ext in ordered_extensions:
|
||||
for inj_info in ext.injections:
|
||||
if inj_info.type == "modifier":
|
||||
if inj_info.name not in self._modifiers:
|
||||
self._modifiers[inj_info.name] = ModifierInjectionPoint()
|
||||
self._modifiers[inj_info.name].add(inj_info.function, inj_info.order)
|
||||
|
||||
else:
|
||||
if inj_info.name in self._overrides:
|
||||
raise Exception(f"Already overloaded - {inj_info.name}")
|
||||
self._overrides[inj_info.name] = inj_info.function
|
||||
|
||||
def call_modifier(self, name: str, *args, **kwargs):
|
||||
if name in self._modifiers:
|
||||
self._modifiers[name](*args, **kwargs)
|
||||
|
||||
def call_override(self, name: str, orig_func: Callable, *args, **kwargs):
|
||||
if name in self._overrides:
|
||||
return self._overrides[name](orig_func, *args, **kwargs)
|
||||
else:
|
||||
return orig_func(*args, **kwargs)
|
||||
|
||||
# TODO: is there any need in such high abstarction
|
||||
# @contextmanager
|
||||
# def patch_extensions(self):
|
||||
# exit_stack = ExitStack()
|
||||
# try:
|
||||
# for ext in self.extensions:
|
||||
# exit_stack.enter_context(ext.patch_extension(self))
|
||||
#
|
||||
# yield None
|
||||
#
|
||||
# finally:
|
||||
# exit_stack.close()
|
||||
|
||||
@contextmanager
|
||||
def patch_attention_processor(self, unet: UNet2DConditionModel, attn_processor_cls: object):
|
||||
unet_orig_processors = unet.attn_processors
|
||||
exit_stack = ExitStack()
|
||||
try:
|
||||
# just to be sure that attentions have not same processor instance
|
||||
attn_procs = {}
|
||||
for name in unet.attn_processors.keys():
|
||||
attn_procs[name] = attn_processor_cls()
|
||||
unet.set_attn_processor(attn_procs)
|
||||
|
||||
for ext in self.extensions:
|
||||
exit_stack.enter_context(ext.patch_attention_processor(attn_processor_cls))
|
||||
|
||||
yield None
|
||||
|
||||
finally:
|
||||
unet.set_attn_processor(unet_orig_processors)
|
||||
exit_stack.close()
|
||||
|
||||
@contextmanager
|
||||
def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
|
||||
exit_stack = ExitStack()
|
||||
try:
|
||||
changed_keys = set()
|
||||
changed_unknown_keys = {}
|
||||
|
||||
ordered_extensions = sorted(self.extensions, reverse=True, key=lambda ext: ext.priority)
|
||||
for ext in ordered_extensions:
|
||||
patch_result = exit_stack.enter_context(ext.patch_unet(state_dict, unet))
|
||||
if patch_result is None:
|
||||
continue
|
||||
new_keys, new_unk_keys = patch_result
|
||||
changed_keys.update(new_keys)
|
||||
# skip already seen keys, as new weight might be changed
|
||||
for k, v in new_unk_keys.items():
|
||||
if k in changed_unknown_keys:
|
||||
continue
|
||||
changed_unknown_keys[k] = v
|
||||
|
||||
yield None
|
||||
|
||||
finally:
|
||||
exit_stack.close()
|
||||
assert hasattr(unet, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule()
|
||||
with torch.no_grad():
|
||||
for module_key in changed_keys:
|
||||
weight = state_dict[module_key]
|
||||
unet.get_submodule(module_key).weight.copy_(
|
||||
weight, non_blocking=TorchDevice.get_non_blocking(weight.device)
|
||||
)
|
||||
for module_key, weight in changed_unknown_keys.items():
|
||||
unet.get_submodule(module_key).weight.copy_(
|
||||
weight, non_blocking=TorchDevice.get_non_blocking(weight.device)
|
||||
)
|
Reference in New Issue
Block a user