mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Remove overrides logic for now
This commit is contained in:
parent
3a9dda9177
commit
7e00526999
@ -45,7 +45,7 @@ class StableDiffusionBackend:
|
||||
ext_manager.callbacks.pre_step(ctx, ext_manager)
|
||||
|
||||
# ext: tiles? [override: step]
|
||||
ctx.step_output = ext_manager.overrides.step(self.step, ctx, ext_manager)
|
||||
ctx.step_output = self.step(ctx, ext_manager)
|
||||
|
||||
# ext: inpaint[post_step, priority=high] (apply mask to preview on non-inpaint models)
|
||||
# ext: preview[post_step, priority=low]
|
||||
@ -73,7 +73,7 @@ class StableDiffusionBackend:
|
||||
ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2)
|
||||
|
||||
# ext: override apply_cfg
|
||||
ctx.noise_pred = ext_manager.overrides.apply_cfg(self.apply_cfg, ctx)
|
||||
ctx.noise_pred = self.apply_cfg(ctx)
|
||||
|
||||
# ext: cfg_rescale [modify_noise_prediction]
|
||||
# TODO: rename
|
||||
|
@ -26,18 +26,6 @@ def callback(name: str, order: int = 0):
|
||||
return _decorator
|
||||
|
||||
|
||||
def override(name: str):
|
||||
def _decorator(func):
|
||||
func.__inj_info__ = {
|
||||
"type": "override",
|
||||
"name": name,
|
||||
"order": None,
|
||||
}
|
||||
return func
|
||||
|
||||
return _decorator
|
||||
|
||||
|
||||
class ExtensionBase:
|
||||
def __init__(self, priority: int):
|
||||
self.priority = priority
|
||||
|
@ -49,16 +49,6 @@ class ExtCallbacksApi(ABC):
|
||||
pass
|
||||
|
||||
|
||||
class ExtOverridesApi(ABC):
|
||||
@abstractmethod
|
||||
def step(self, orig_func: Callable, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def apply_cfg(self, orig_func: Callable, ctx: DenoiseContext):
|
||||
pass
|
||||
|
||||
|
||||
class ProxyCallsClass:
|
||||
def __init__(self, handler):
|
||||
self._handler = handler
|
||||
@ -86,17 +76,13 @@ class ExtensionsManager:
|
||||
def __init__(self):
|
||||
self.extensions = []
|
||||
|
||||
self._overrides = {}
|
||||
self._callbacks = {}
|
||||
|
||||
self.callbacks: ExtCallbacksApi = ProxyCallsClass(self.call_callback)
|
||||
self.overrides: ExtOverridesApi = ProxyCallsClass(self.call_override)
|
||||
|
||||
def add_extension(self, ext: ExtensionBase):
|
||||
self.extensions.append(ext)
|
||||
ordered_extensions = sorted(self.extensions, reverse=True, key=lambda ext: ext.priority)
|
||||
|
||||
self._overrides.clear()
|
||||
self._callbacks.clear()
|
||||
|
||||
for ext in ordered_extensions:
|
||||
@ -107,20 +93,12 @@ class ExtensionsManager:
|
||||
self._callbacks[inj_info.name].add(inj_info.function, inj_info.order)
|
||||
|
||||
else:
|
||||
if inj_info.name in self._overrides:
|
||||
raise Exception(f"Already overloaded - {inj_info.name}")
|
||||
self._overrides[inj_info.name] = inj_info.function
|
||||
raise Exception(f"Unsupported injection type: {inj_info.type}")
|
||||
|
||||
def call_callback(self, name: str, *args, **kwargs):
|
||||
if name in self._callbacks:
|
||||
self._callbacks[name](*args, **kwargs)
|
||||
|
||||
def call_override(self, name: str, orig_func: Callable, *args, **kwargs):
|
||||
if name in self._overrides:
|
||||
return self._overrides[name](orig_func, *args, **kwargs)
|
||||
else:
|
||||
return orig_func(*args, **kwargs)
|
||||
|
||||
# TODO: is there any need in such high abstarction
|
||||
# @contextmanager
|
||||
# def patch_extensions(self):
|
||||
|
Loading…
Reference in New Issue
Block a user