mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Ryan's suggested changes to extension manager/extensions
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
This commit is contained in:
@ -1,8 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import ExitStack, contextmanager
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
@ -12,102 +10,50 @@ from invokeai.app.services.session_processor.session_processor_common import Can
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
|
||||
|
||||
|
||||
class ExtCallbacksApi(ABC):
|
||||
@abstractmethod
|
||||
def setup(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pre_denoise_loop(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def post_denoise_loop(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pre_step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def post_step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pre_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def post_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def post_apply_cfg(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
||||
pass
|
||||
|
||||
|
||||
class ProxyCallsClass:
|
||||
def __init__(self, handler):
|
||||
self._handler = handler
|
||||
|
||||
def __getattr__(self, item):
|
||||
return partial(self._handler, item)
|
||||
|
||||
|
||||
class CallbackInjectionPoint:
|
||||
def __init__(self):
|
||||
self.handlers = {}
|
||||
|
||||
def add(self, func: Callable, order: int):
|
||||
if order not in self.handlers:
|
||||
self.handlers[order] = []
|
||||
self.handlers[order].append(func)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
for order in sorted(self.handlers.keys(), reverse=True):
|
||||
for handler in self.handlers[order]:
|
||||
handler(*args, **kwargs)
|
||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||
from invokeai.backend.stable_diffusion.extensions.base import CallbackFunctionWithMetadata, ExtensionBase
|
||||
|
||||
|
||||
class ExtensionsManager:
|
||||
def __init__(self, is_canceled: Optional[Callable[[], bool]] = None):
|
||||
self.extensions: List[ExtensionBase] = []
|
||||
self._is_canceled = is_canceled
|
||||
|
||||
self._callbacks: Dict[str, CallbackInjectionPoint] = {}
|
||||
self.callbacks: ExtCallbacksApi = ProxyCallsClass(self.call_callback)
|
||||
self._extensions: List[ExtensionBase] = []
|
||||
self._ordered_callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}
|
||||
|
||||
def add_extension(self, ext: ExtensionBase):
|
||||
self.extensions.append(ext)
|
||||
def add_extension(self, extension: ExtensionBase):
|
||||
self._extensions.append(extension)
|
||||
self._regenerate_ordered_callbacks()
|
||||
|
||||
self._callbacks.clear()
|
||||
def _regenerate_ordered_callbacks(self):
|
||||
"""Regenerates self._ordered_callbacks. Intended to be called each time a new extension is added."""
|
||||
self._ordered_callbacks = {}
|
||||
|
||||
for ext in self.extensions:
|
||||
for inj_info in ext.injections:
|
||||
if inj_info.type == "callback":
|
||||
if inj_info.name not in self._callbacks:
|
||||
self._callbacks[inj_info.name] = CallbackInjectionPoint()
|
||||
self._callbacks[inj_info.name].add(inj_info.function, inj_info.order)
|
||||
# Fill the ordered callbacks dictionary.
|
||||
for extension in self._extensions:
|
||||
for callback_type, callbacks in extension.get_callbacks().items():
|
||||
if callback_type not in self._ordered_callbacks:
|
||||
self._ordered_callbacks[callback_type] = []
|
||||
self._ordered_callbacks[callback_type].extend(callbacks)
|
||||
|
||||
else:
|
||||
raise Exception(f"Unsupported injection type: {inj_info.type}")
|
||||
# Sort each callback list.
|
||||
for callback_type, callbacks in self._ordered_callbacks.items():
|
||||
self._ordered_callbacks[callback_type] = sorted(callbacks, key=lambda x: x.metadata.order)
|
||||
|
||||
def call_callback(self, name: str, *args, **kwargs):
|
||||
def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext):
|
||||
# TODO: add to patchers too?
|
||||
# and if so, should it be only in beginning of function or in for loop
|
||||
if self._is_canceled and self._is_canceled():
|
||||
raise CanceledException
|
||||
|
||||
if name in self._callbacks:
|
||||
self._callbacks[name](*args, **kwargs)
|
||||
callbacks = self._ordered_callbacks.get(callback_type, [])
|
||||
for cb in callbacks:
|
||||
cb.function(ctx)
|
||||
|
||||
@contextmanager
|
||||
def patch_extensions(self, context: DenoiseContext):
|
||||
with ExitStack() as exit_stack:
|
||||
for ext in self.extensions:
|
||||
for ext in self._extensions:
|
||||
exit_stack.enter_context(ext.patch_extension(context))
|
||||
|
||||
yield None
|
||||
|
Reference in New Issue
Block a user