2024-07-12 17:31:26 +00:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
from contextlib import ExitStack, contextmanager
|
|
|
|
from functools import partial
|
|
|
|
from typing import TYPE_CHECKING, Callable, Dict
|
|
|
|
|
|
|
|
import torch
|
|
|
|
from diffusers import UNet2DConditionModel
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
|
|
|
from invokeai.backend.stable_diffusion.extensions import ExtensionBase
|
|
|
|
|
|
|
|
|
2024-07-12 19:01:05 +00:00
|
|
|
class ExtCallbacksApi(ABC):
    """Catalogue of the callback hooks an extension can attach to.

    Every abstract method names one callback slot, and all slots share the
    signature ``(ctx, ext_manager)``. ``ExtensionsManager.callbacks`` is
    annotated with this type, so ``manager.callbacks.<hook>(ctx, manager)`` is
    checked against these names by type checkers. At runtime the call is
    dispatched by name through a proxy object rather than through this class.

    When each hook actually fires is decided by the caller, which is not part
    of this file. The per-hook notes below only restate what the names suggest.
    """

    @abstractmethod
    def setup(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
        """Hook ``setup``: preparation step (per its name)."""

    @abstractmethod
    def pre_denoise_loop(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
        """Hook ``pre_denoise_loop``: before the denoising loop (per its name)."""

    @abstractmethod
    def post_denoise_loop(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
        """Hook ``post_denoise_loop``: after the denoising loop (per its name)."""

    @abstractmethod
    def pre_step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
        """Hook ``pre_step``: before a denoising step (per its name)."""

    @abstractmethod
    def post_step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
        """Hook ``post_step``: after a denoising step (per its name)."""

    @abstractmethod
    def pre_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
        """Hook ``pre_unet``: before the UNet call (per its name)."""

    @abstractmethod
    def post_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
        """Hook ``post_unet``: after the UNet call (per its name)."""

    @abstractmethod
    def post_apply_cfg(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
        """Hook ``post_apply_cfg``: after CFG has been applied (per its name)."""
|
|
|
|
|
|
|
|
|
|
|
|
class ProxyCallsClass:
    """Turn attribute calls into calls of one name-dispatching handler.

    ``proxy.some_name(*args, **kwargs)`` becomes
    ``handler("some_name", *args, **kwargs)``.

    Names starting with an underscore are not proxied. This covers private
    attributes and dunder protocol hooks such as ``__deepcopy__`` or
    ``__setstate__``. They raise ``AttributeError`` as usual, so ``hasattr``
    and ``copy`` probes see a normal object instead of a callable that
    silently routes to the handler.
    """

    def __init__(self, handler: Callable):
        # Invoked as handler(name, *args, **kwargs).
        self._handler = handler

    def __getattr__(self, item: str):
        # __getattr__ only runs when normal attribute lookup fails.
        # Refusing underscore names also prevents infinite recursion on
        # ``self._handler``. That case arises when the instance was created
        # via __new__ (e.g. by copy) and ``_handler`` has not been set yet.
        if item.startswith("_"):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {item!r}")
        return partial(self._handler, item)
|
|
|
|
|
|
|
|
|
2024-07-12 19:01:05 +00:00
|
|
|
class CallbackInjectionPoint:
    """Handlers registered under a single callback name, grouped by priority.

    Each handler is stored in a bucket keyed by its ``order``. Calling the
    injection point walks the buckets from the highest ``order`` to the
    lowest. Within one bucket, handlers run in the order they were added.
    """

    def __init__(self):
        # order -> handlers added with that order, in insertion order
        self.handlers = {}

    def add(self, func: Callable, order: int):
        """Append ``func`` to the bucket for ``order``, creating the bucket if needed."""
        bucket = self.handlers.setdefault(order, [])
        bucket.append(func)

    def __call__(self, *args, **kwargs):
        """Call every handler with the same arguments; return values are discarded."""
        descending_orders = sorted(self.handlers, reverse=True)
        for priority in descending_orders:
            for callback in self.handlers[priority]:
                callback(*args, **kwargs)
|
2024-07-12 17:31:26 +00:00
|
|
|
|
|
|
|
|
|
|
|
class ExtensionsManager:
    """Registry of extensions and dispatcher for their callback injections.

    Extensions are kept in registration order. Every injection an extension
    declares with ``type == "callback"`` is grouped by name into a
    ``CallbackInjectionPoint``.

    ``self.callbacks`` is a proxy: an attribute call such as
    ``self.callbacks.pre_step(ctx, self)`` is forwarded to
    ``self.call_callback("pre_step", ctx, self)``.
    """

    def __init__(self):
        # Registered extensions, in the order they were added.
        self.extensions = []

        # Callback name -> injection point holding every handler registered under that name.
        self._callbacks = {}
        # Annotated as ExtCallbacksApi so type checkers know the hook names.
        # At runtime it is a proxy that forwards every attribute call to
        # call_callback by name.
        self.callbacks: ExtCallbacksApi = ProxyCallsClass(self.call_callback)

    def add_extension(self, ext: ExtensionBase):
        """Register ``ext`` and rebuild the callback table from all extensions.

        The new table is built into a fresh dict before any state is changed.
        An extension declaring an unsupported injection type is therefore
        rejected without being added, and the existing callback table is left
        intact. Previously, the extension was appended and the table cleared
        before validation, which left the manager half-built and made every
        later ``add_extension`` call fail too.

        Raises:
            ValueError: if any injection has a ``type`` other than ``"callback"``.
        """
        callbacks = self._build_callbacks([*self.extensions, ext])
        # Commit only after the whole table was built successfully.
        self.extensions.append(ext)
        self._callbacks = callbacks

    @staticmethod
    def _build_callbacks(extensions) -> Dict[str, CallbackInjectionPoint]:
        """Group the callback injections of ``extensions`` by name, in registration order.

        Each injection is expected to expose ``type``, ``name``, ``function``
        and ``order`` attributes.
        """
        callbacks: Dict[str, CallbackInjectionPoint] = {}
        for extension in extensions:
            for inj_info in extension.injections:
                if inj_info.type != "callback":
                    raise ValueError(f"Unsupported injection type: {inj_info.type}")
                if inj_info.name not in callbacks:
                    callbacks[inj_info.name] = CallbackInjectionPoint()
                callbacks[inj_info.name].add(inj_info.function, inj_info.order)
        return callbacks

    def call_callback(self, name: str, *args, **kwargs):
        """Invoke every handler registered under callback ``name`` with the given arguments.

        A name with no registered handlers is a silent no-op.
        """
        injection_point = self._callbacks.get(name)
        if injection_point is not None:
            injection_point(*args, **kwargs)

    @contextmanager
    def patch_extensions(self, context: DenoiseContext):
        """Keep every extension's ``patch_extension(context)`` entered for the duration of the block.

        Extensions are entered in registration order. ``ExitStack`` exits them
        in reverse order, including when the body raises or when entering a
        later extension fails.
        """
        with ExitStack() as exit_stack:
            for ext in self.extensions:
                exit_stack.enter_context(ext.patch_extension(context))

            yield None

    @contextmanager
    def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
        """Placeholder context manager for patching ``unet``; it currently does nothing.

        ``state_dict`` and ``unet`` are accepted but not used yet.
        """
        # TODO: create logic in PR with extension which uses it
        yield None
|