mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Remove overrides logic for now
This commit is contained in:
parent
3a9dda9177
commit
7e00526999
@ -45,7 +45,7 @@ class StableDiffusionBackend:
|
|||||||
ext_manager.callbacks.pre_step(ctx, ext_manager)
|
ext_manager.callbacks.pre_step(ctx, ext_manager)
|
||||||
|
|
||||||
# ext: tiles? [override: step]
|
# ext: tiles? [override: step]
|
||||||
ctx.step_output = ext_manager.overrides.step(self.step, ctx, ext_manager)
|
ctx.step_output = self.step(ctx, ext_manager)
|
||||||
|
|
||||||
# ext: inpaint[post_step, priority=high] (apply mask to preview on non-inpaint models)
|
# ext: inpaint[post_step, priority=high] (apply mask to preview on non-inpaint models)
|
||||||
# ext: preview[post_step, priority=low]
|
# ext: preview[post_step, priority=low]
|
||||||
@ -73,7 +73,7 @@ class StableDiffusionBackend:
|
|||||||
ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2)
|
ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2)
|
||||||
|
|
||||||
# ext: override apply_cfg
|
# ext: override apply_cfg
|
||||||
ctx.noise_pred = ext_manager.overrides.apply_cfg(self.apply_cfg, ctx)
|
ctx.noise_pred = self.apply_cfg(ctx)
|
||||||
|
|
||||||
# ext: cfg_rescale [modify_noise_prediction]
|
# ext: cfg_rescale [modify_noise_prediction]
|
||||||
# TODO: rename
|
# TODO: rename
|
||||||
|
@ -26,18 +26,6 @@ def callback(name: str, order: int = 0):
|
|||||||
return _decorator
|
return _decorator
|
||||||
|
|
||||||
|
|
||||||
def override(name: str):
|
|
||||||
def _decorator(func):
|
|
||||||
func.__inj_info__ = {
|
|
||||||
"type": "override",
|
|
||||||
"name": name,
|
|
||||||
"order": None,
|
|
||||||
}
|
|
||||||
return func
|
|
||||||
|
|
||||||
return _decorator
|
|
||||||
|
|
||||||
|
|
||||||
class ExtensionBase:
|
class ExtensionBase:
|
||||||
def __init__(self, priority: int):
|
def __init__(self, priority: int):
|
||||||
self.priority = priority
|
self.priority = priority
|
||||||
|
@ -49,16 +49,6 @@ class ExtCallbacksApi(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ExtOverridesApi(ABC):
|
|
||||||
@abstractmethod
|
|
||||||
def step(self, orig_func: Callable, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def apply_cfg(self, orig_func: Callable, ctx: DenoiseContext):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ProxyCallsClass:
|
class ProxyCallsClass:
|
||||||
def __init__(self, handler):
|
def __init__(self, handler):
|
||||||
self._handler = handler
|
self._handler = handler
|
||||||
@ -86,17 +76,13 @@ class ExtensionsManager:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.extensions = []
|
self.extensions = []
|
||||||
|
|
||||||
self._overrides = {}
|
|
||||||
self._callbacks = {}
|
self._callbacks = {}
|
||||||
|
|
||||||
self.callbacks: ExtCallbacksApi = ProxyCallsClass(self.call_callback)
|
self.callbacks: ExtCallbacksApi = ProxyCallsClass(self.call_callback)
|
||||||
self.overrides: ExtOverridesApi = ProxyCallsClass(self.call_override)
|
|
||||||
|
|
||||||
def add_extension(self, ext: ExtensionBase):
|
def add_extension(self, ext: ExtensionBase):
|
||||||
self.extensions.append(ext)
|
self.extensions.append(ext)
|
||||||
ordered_extensions = sorted(self.extensions, reverse=True, key=lambda ext: ext.priority)
|
ordered_extensions = sorted(self.extensions, reverse=True, key=lambda ext: ext.priority)
|
||||||
|
|
||||||
self._overrides.clear()
|
|
||||||
self._callbacks.clear()
|
self._callbacks.clear()
|
||||||
|
|
||||||
for ext in ordered_extensions:
|
for ext in ordered_extensions:
|
||||||
@ -107,20 +93,12 @@ class ExtensionsManager:
|
|||||||
self._callbacks[inj_info.name].add(inj_info.function, inj_info.order)
|
self._callbacks[inj_info.name].add(inj_info.function, inj_info.order)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if inj_info.name in self._overrides:
|
raise Exception(f"Unsupported injection type: {inj_info.type}")
|
||||||
raise Exception(f"Already overloaded - {inj_info.name}")
|
|
||||||
self._overrides[inj_info.name] = inj_info.function
|
|
||||||
|
|
||||||
def call_callback(self, name: str, *args, **kwargs):
|
def call_callback(self, name: str, *args, **kwargs):
|
||||||
if name in self._callbacks:
|
if name in self._callbacks:
|
||||||
self._callbacks[name](*args, **kwargs)
|
self._callbacks[name](*args, **kwargs)
|
||||||
|
|
||||||
def call_override(self, name: str, orig_func: Callable, *args, **kwargs):
|
|
||||||
if name in self._overrides:
|
|
||||||
return self._overrides[name](orig_func, *args, **kwargs)
|
|
||||||
else:
|
|
||||||
return orig_func(*args, **kwargs)
|
|
||||||
|
|
||||||
# TODO: is there any need in such high abstarction
|
# TODO: is there any need in such high abstarction
|
||||||
# @contextmanager
|
# @contextmanager
|
||||||
# def patch_extensions(self):
|
# def patch_extensions(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user