mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Rename modifiers to callbacks, convert order to int, a bit unify injection points
This commit is contained in:
@ -15,29 +15,29 @@ if TYPE_CHECKING:
|
||||
from invokeai.backend.stable_diffusion.extensions import ExtensionBase
|
||||
|
||||
|
||||
class ExtModifiersApi(ABC):
|
||||
class ExtCallbacksApi(ABC):
|
||||
@abstractmethod
|
||||
def pre_denoise_loop(self, ctx: DenoiseContext):
|
||||
def pre_denoise_loop(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def post_denoise_loop(self, ctx: DenoiseContext):
|
||||
def post_denoise_loop(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pre_step(self, ctx: DenoiseContext):
|
||||
def pre_step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def post_step(self, ctx: DenoiseContext):
|
||||
def post_step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def modify_noise_prediction(self, ctx: DenoiseContext):
|
||||
def modify_noise_prediction(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pre_unet_forward(self, ctx: DenoiseContext):
|
||||
def pre_unet_forward(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@ -51,7 +51,7 @@ class ExtOverridesApi(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def combine_noise(self, orig_func: Callable, ctx: DenoiseContext):
|
||||
def apply_cfg(self, orig_func: Callable, ctx: DenoiseContext):
|
||||
pass
|
||||
|
||||
|
||||
@ -63,27 +63,19 @@ class ProxyCallsClass:
|
||||
return partial(self._handler, item)
|
||||
|
||||
|
||||
class ModifierInjectionPoint:
|
||||
class CallbackInjectionPoint:
|
||||
def __init__(self):
|
||||
self.first = []
|
||||
self.any = []
|
||||
self.last = []
|
||||
self.handlers = {}
|
||||
|
||||
def add(self, func: Callable, order: str):
|
||||
if order == "first":
|
||||
self.first.append(func)
|
||||
elif order == "last":
|
||||
self.last.append(func)
|
||||
else: # elif order == "any":
|
||||
self.any.append(func)
|
||||
def add(self, func: Callable, order: int):
|
||||
if order not in self.handlers:
|
||||
self.handlers[order] = []
|
||||
self.handlers[order].append(func)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
for func in self.first:
|
||||
func(*args, **kwargs)
|
||||
for func in self.any:
|
||||
func(*args, **kwargs)
|
||||
for func in reversed(self.last):
|
||||
func(*args, **kwargs)
|
||||
for order in sorted(self.handlers.keys(), reverse=True):
|
||||
for handler in self.handlers[order]:
|
||||
handler(*args, **kwargs)
|
||||
|
||||
|
||||
class ExtensionsManager:
|
||||
@ -91,9 +83,9 @@ class ExtensionsManager:
|
||||
self.extensions = []
|
||||
|
||||
self._overrides = {}
|
||||
self._modifiers = {}
|
||||
self._callbacks = {}
|
||||
|
||||
self.modifiers: ExtModifiersApi = ProxyCallsClass(self.call_modifier)
|
||||
self.callbacks: ExtCallbacksApi = ProxyCallsClass(self.call_callback)
|
||||
self.overrides: ExtOverridesApi = ProxyCallsClass(self.call_override)
|
||||
|
||||
def add_extension(self, ext: ExtensionBase):
|
||||
@ -101,23 +93,23 @@ class ExtensionsManager:
|
||||
ordered_extensions = sorted(self.extensions, reverse=True, key=lambda ext: ext.priority)
|
||||
|
||||
self._overrides.clear()
|
||||
self._modifiers.clear()
|
||||
self._callbacks.clear()
|
||||
|
||||
for ext in ordered_extensions:
|
||||
for inj_info in ext.injections:
|
||||
if inj_info.type == "modifier":
|
||||
if inj_info.name not in self._modifiers:
|
||||
self._modifiers[inj_info.name] = ModifierInjectionPoint()
|
||||
self._modifiers[inj_info.name].add(inj_info.function, inj_info.order)
|
||||
if inj_info.type == "callback":
|
||||
if inj_info.name not in self._callbacks:
|
||||
self._callbacks[inj_info.name] = CallbackInjectionPoint()
|
||||
self._callbacks[inj_info.name].add(inj_info.function, inj_info.order)
|
||||
|
||||
else:
|
||||
if inj_info.name in self._overrides:
|
||||
raise Exception(f"Already overloaded - {inj_info.name}")
|
||||
self._overrides[inj_info.name] = inj_info.function
|
||||
|
||||
def call_modifier(self, name: str, *args, **kwargs):
|
||||
if name in self._modifiers:
|
||||
self._modifiers[name](*args, **kwargs)
|
||||
def call_callback(self, name: str, *args, **kwargs):
|
||||
if name in self._callbacks:
|
||||
self._callbacks[name](*args, **kwargs)
|
||||
|
||||
def call_override(self, name: str, orig_func: Callable, *args, **kwargs):
|
||||
if name in self._overrides:
|
||||
|
Reference in New Issue
Block a user