Remove overrides logic for now

This commit is contained in:
Sergey Borisov 2024-07-13 00:28:56 +03:00
parent 3a9dda9177
commit 7e00526999
3 changed files with 3 additions and 37 deletions

View File

@ -45,7 +45,7 @@ class StableDiffusionBackend:
ext_manager.callbacks.pre_step(ctx, ext_manager) ext_manager.callbacks.pre_step(ctx, ext_manager)
# ext: tiles? [override: step] # ext: tiles? [override: step]
ctx.step_output = ext_manager.overrides.step(self.step, ctx, ext_manager) ctx.step_output = self.step(ctx, ext_manager)
# ext: inpaint[post_step, priority=high] (apply mask to preview on non-inpaint models) # ext: inpaint[post_step, priority=high] (apply mask to preview on non-inpaint models)
# ext: preview[post_step, priority=low] # ext: preview[post_step, priority=low]
@ -73,7 +73,7 @@ class StableDiffusionBackend:
ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2) ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2)
# ext: override apply_cfg # ext: override apply_cfg
ctx.noise_pred = ext_manager.overrides.apply_cfg(self.apply_cfg, ctx) ctx.noise_pred = self.apply_cfg(ctx)
# ext: cfg_rescale [modify_noise_prediction] # ext: cfg_rescale [modify_noise_prediction]
# TODO: rename # TODO: rename

View File

@ -26,18 +26,6 @@ def callback(name: str, order: int = 0):
return _decorator return _decorator
def override(name: str):
def _decorator(func):
func.__inj_info__ = {
"type": "override",
"name": name,
"order": None,
}
return func
return _decorator
class ExtensionBase: class ExtensionBase:
def __init__(self, priority: int): def __init__(self, priority: int):
self.priority = priority self.priority = priority

View File

@ -49,16 +49,6 @@ class ExtCallbacksApi(ABC):
pass pass
class ExtOverridesApi(ABC):
@abstractmethod
def step(self, orig_func: Callable, ctx: DenoiseContext, ext_manager: ExtensionsManager):
pass
@abstractmethod
def apply_cfg(self, orig_func: Callable, ctx: DenoiseContext):
pass
class ProxyCallsClass: class ProxyCallsClass:
def __init__(self, handler): def __init__(self, handler):
self._handler = handler self._handler = handler
@ -86,17 +76,13 @@ class ExtensionsManager:
def __init__(self): def __init__(self):
self.extensions = [] self.extensions = []
self._overrides = {}
self._callbacks = {} self._callbacks = {}
self.callbacks: ExtCallbacksApi = ProxyCallsClass(self.call_callback) self.callbacks: ExtCallbacksApi = ProxyCallsClass(self.call_callback)
self.overrides: ExtOverridesApi = ProxyCallsClass(self.call_override)
def add_extension(self, ext: ExtensionBase): def add_extension(self, ext: ExtensionBase):
self.extensions.append(ext) self.extensions.append(ext)
ordered_extensions = sorted(self.extensions, reverse=True, key=lambda ext: ext.priority) ordered_extensions = sorted(self.extensions, reverse=True, key=lambda ext: ext.priority)
self._overrides.clear()
self._callbacks.clear() self._callbacks.clear()
for ext in ordered_extensions: for ext in ordered_extensions:
@ -107,20 +93,12 @@ class ExtensionsManager:
self._callbacks[inj_info.name].add(inj_info.function, inj_info.order) self._callbacks[inj_info.name].add(inj_info.function, inj_info.order)
else: else:
if inj_info.name in self._overrides: raise Exception(f"Unsupported injection type: {inj_info.type}")
raise Exception(f"Already overloaded - {inj_info.name}")
self._overrides[inj_info.name] = inj_info.function
def call_callback(self, name: str, *args, **kwargs): def call_callback(self, name: str, *args, **kwargs):
if name in self._callbacks: if name in self._callbacks:
self._callbacks[name](*args, **kwargs) self._callbacks[name](*args, **kwargs)
def call_override(self, name: str, orig_func: Callable, *args, **kwargs):
if name in self._overrides:
return self._overrides[name](orig_func, *args, **kwargs)
else:
return orig_func(*args, **kwargs)
# TODO: is there any need in such high abstarction # TODO: is there any need in such high abstarction
# @contextmanager # @contextmanager
# def patch_extensions(self): # def patch_extensions(self):