mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Ryan's suggested changes to extension manager/extensions
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
This commit is contained in:
parent
710dc6b487
commit
0c56d4a581
@ -57,6 +57,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
|||||||
)
|
)
|
||||||
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
|
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
|
||||||
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
|
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
|
||||||
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||||
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
|
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
|
||||||
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
||||||
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||||
@ -790,7 +791,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
ext_manager.add_extension(PreviewExt(step_callback))
|
ext_manager.add_extension(PreviewExt(step_callback))
|
||||||
|
|
||||||
# ext: t2i/ip adapter
|
# ext: t2i/ip adapter
|
||||||
ext_manager.callbacks.setup(denoise_ctx, ext_manager)
|
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
|
||||||
|
|
||||||
unet_info = context.models.load(self.unet.unet)
|
unet_info = context.models.load(self.unet.unet)
|
||||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||||
|
@ -8,6 +8,7 @@ from tqdm.auto import tqdm
|
|||||||
from invokeai.app.services.config.config_default import get_config
|
from invokeai.app.services.config.config_default import get_config
|
||||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, UNetKwargs
|
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, UNetKwargs
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
|
||||||
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||||
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
||||||
|
|
||||||
|
|
||||||
@ -41,23 +42,23 @@ class StableDiffusionBackend:
|
|||||||
|
|
||||||
# ext: inpaint[pre_denoise_loop, priority=normal] (maybe init, but not sure if it needed)
|
# ext: inpaint[pre_denoise_loop, priority=normal] (maybe init, but not sure if it needed)
|
||||||
# ext: preview[pre_denoise_loop, priority=low]
|
# ext: preview[pre_denoise_loop, priority=low]
|
||||||
ext_manager.callbacks.pre_denoise_loop(ctx, ext_manager)
|
ext_manager.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, ctx)
|
||||||
|
|
||||||
for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.inputs.timesteps)): # noqa: B020
|
for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.inputs.timesteps)): # noqa: B020
|
||||||
# ext: inpaint (apply mask to latents on non-inpaint models)
|
# ext: inpaint (apply mask to latents on non-inpaint models)
|
||||||
ext_manager.callbacks.pre_step(ctx, ext_manager)
|
ext_manager.run_callback(ExtensionCallbackType.PRE_STEP, ctx)
|
||||||
|
|
||||||
# ext: tiles? [override: step]
|
# ext: tiles? [override: step]
|
||||||
ctx.step_output = self.step(ctx, ext_manager)
|
ctx.step_output = self.step(ctx, ext_manager)
|
||||||
|
|
||||||
# ext: inpaint[post_step, priority=high] (apply mask to preview on non-inpaint models)
|
# ext: inpaint[post_step, priority=high] (apply mask to preview on non-inpaint models)
|
||||||
# ext: preview[post_step, priority=low]
|
# ext: preview[post_step, priority=low]
|
||||||
ext_manager.callbacks.post_step(ctx, ext_manager)
|
ext_manager.run_callback(ExtensionCallbackType.POST_STEP, ctx)
|
||||||
|
|
||||||
ctx.latents = ctx.step_output.prev_sample
|
ctx.latents = ctx.step_output.prev_sample
|
||||||
|
|
||||||
# ext: inpaint[post_denoise_loop] (restore unmasked part)
|
# ext: inpaint[post_denoise_loop] (restore unmasked part)
|
||||||
ext_manager.callbacks.post_denoise_loop(ctx, ext_manager)
|
ext_manager.run_callback(ExtensionCallbackType.POST_DENOISE_LOOP, ctx)
|
||||||
return ctx.latents
|
return ctx.latents
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
@ -80,7 +81,7 @@ class StableDiffusionBackend:
|
|||||||
|
|
||||||
# ext: cfg_rescale [modify_noise_prediction]
|
# ext: cfg_rescale [modify_noise_prediction]
|
||||||
# TODO: rename
|
# TODO: rename
|
||||||
ext_manager.callbacks.post_apply_cfg(ctx, ext_manager)
|
ext_manager.run_callback(ExtensionCallbackType.POST_APPLY_CFG, ctx)
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.inputs.scheduler_step_kwargs)
|
step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.inputs.scheduler_step_kwargs)
|
||||||
@ -120,14 +121,14 @@ class StableDiffusionBackend:
|
|||||||
ctx.inputs.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode)
|
ctx.inputs.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode)
|
||||||
|
|
||||||
# ext: controlnet/ip/t2i [pre_unet]
|
# ext: controlnet/ip/t2i [pre_unet]
|
||||||
ext_manager.callbacks.pre_unet(ctx, ext_manager)
|
ext_manager.run_callback(ExtensionCallbackType.PRE_UNET, ctx)
|
||||||
|
|
||||||
# ext: inpaint [pre_unet, priority=low]
|
# ext: inpaint [pre_unet, priority=low]
|
||||||
# or
|
# or
|
||||||
# ext: inpaint [override: unet_forward]
|
# ext: inpaint [override: unet_forward]
|
||||||
noise_pred = self._unet_forward(**vars(ctx.unet_kwargs))
|
noise_pred = self._unet_forward(**vars(ctx.unet_kwargs))
|
||||||
|
|
||||||
ext_manager.callbacks.post_unet(ctx, ext_manager)
|
ext_manager.run_callback(ExtensionCallbackType.POST_UNET, ctx)
|
||||||
|
|
||||||
# clean up locals
|
# clean up locals
|
||||||
ctx.unet_kwargs = None
|
ctx.unet_kwargs = None
|
||||||
|
12
invokeai/backend/stable_diffusion/extension_callback_type.py
Normal file
12
invokeai/backend/stable_diffusion/extension_callback_type.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class ExtensionCallbackType(Enum):
|
||||||
|
SETUP = "setup"
|
||||||
|
PRE_DENOISE_LOOP = "pre_denoise_loop"
|
||||||
|
POST_DENOISE_LOOP = "post_denoise_loop"
|
||||||
|
PRE_STEP = "pre_step"
|
||||||
|
POST_STEP = "post_step"
|
||||||
|
PRE_UNET = "pre_unet"
|
||||||
|
POST_UNET = "post_unet"
|
||||||
|
POST_APPLY_CFG = "post_apply_cfg"
|
@ -2,44 +2,54 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
|
from typing import TYPE_CHECKING, Callable, Dict, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers import UNet2DConditionModel
|
from diffusers import UNet2DConditionModel
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||||
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class InjectionInfo:
|
class CallbackMetadata:
|
||||||
type: str
|
callback_type: ExtensionCallbackType
|
||||||
name: str
|
order: int
|
||||||
order: Optional[int]
|
|
||||||
function: Callable
|
|
||||||
|
|
||||||
|
|
||||||
def callback(name: str, order: int = 0):
|
@dataclass
|
||||||
def _decorator(func):
|
class CallbackFunctionWithMetadata:
|
||||||
func.__inj_info__ = {
|
metadata: CallbackMetadata
|
||||||
"type": "callback",
|
function: Callable[[DenoiseContext], None]
|
||||||
"name": name,
|
|
||||||
"order": order,
|
|
||||||
}
|
def callback(callback_type: ExtensionCallbackType, order: int = 0):
|
||||||
return func
|
def _decorator(function):
|
||||||
|
function._ext_metadata = CallbackMetadata(
|
||||||
|
callback_type=callback_type,
|
||||||
|
order=order,
|
||||||
|
)
|
||||||
|
return function
|
||||||
|
|
||||||
return _decorator
|
return _decorator
|
||||||
|
|
||||||
|
|
||||||
class ExtensionBase:
|
class ExtensionBase:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.injections: List[InjectionInfo] = []
|
self._callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}
|
||||||
|
|
||||||
|
# Register all of the callback methods for this instance.
|
||||||
for func_name in dir(self):
|
for func_name in dir(self):
|
||||||
func = getattr(self, func_name)
|
func = getattr(self, func_name)
|
||||||
if not callable(func) or not hasattr(func, "__inj_info__"):
|
metadata = getattr(func, "_ext_metadata", None)
|
||||||
continue
|
if metadata is not None and isinstance(metadata, CallbackMetadata):
|
||||||
|
if metadata.callback_type not in self._callbacks:
|
||||||
|
self._callbacks[metadata.callback_type] = []
|
||||||
|
self._callbacks[metadata.callback_type].append(CallbackFunctionWithMetadata(metadata, func))
|
||||||
|
|
||||||
self.injections.append(InjectionInfo(**func.__inj_info__, function=func))
|
def get_callbacks(self):
|
||||||
|
return self._callbacks
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def patch_extension(self, context: DenoiseContext):
|
def patch_extension(self, context: DenoiseContext):
|
||||||
|
@ -5,11 +5,11 @@ from typing import TYPE_CHECKING, Callable, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
|
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||||
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: change event to accept image instead of latents
|
# TODO: change event to accept image instead of latents
|
||||||
@ -29,8 +29,8 @@ class PreviewExt(ExtensionBase):
|
|||||||
self.callback = callback
|
self.callback = callback
|
||||||
|
|
||||||
# do last so that all other changes shown
|
# do last so that all other changes shown
|
||||||
@callback("pre_denoise_loop", order=1000)
|
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000)
|
||||||
def initial_preview(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
def initial_preview(self, ctx: DenoiseContext):
|
||||||
self.callback(
|
self.callback(
|
||||||
PipelineIntermediateState(
|
PipelineIntermediateState(
|
||||||
step=-1,
|
step=-1,
|
||||||
@ -42,8 +42,8 @@ class PreviewExt(ExtensionBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# do last so that all other changes shown
|
# do last so that all other changes shown
|
||||||
@callback("post_step", order=1000)
|
@callback(ExtensionCallbackType.POST_STEP, order=1000)
|
||||||
def step_preview(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
def step_preview(self, ctx: DenoiseContext):
|
||||||
if hasattr(ctx.step_output, "denoised"):
|
if hasattr(ctx.step_output, "denoised"):
|
||||||
predicted_original = ctx.step_output.denoised
|
predicted_original = ctx.step_output.denoised
|
||||||
elif hasattr(ctx.step_output, "pred_original_sample"):
|
elif hasattr(ctx.step_output, "pred_original_sample"):
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from contextlib import ExitStack, contextmanager
|
from contextlib import ExitStack, contextmanager
|
||||||
from functools import partial
|
|
||||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
|
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -12,102 +10,50 @@ from invokeai.app.services.session_processor.session_processor_common import Can
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||||
|
from invokeai.backend.stable_diffusion.extensions.base import CallbackFunctionWithMetadata, ExtensionBase
|
||||||
|
|
||||||
class ExtCallbacksApi(ABC):
|
|
||||||
@abstractmethod
|
|
||||||
def setup(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def pre_denoise_loop(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def post_denoise_loop(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def pre_step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def post_step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def pre_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def post_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def post_apply_cfg(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ProxyCallsClass:
|
|
||||||
def __init__(self, handler):
|
|
||||||
self._handler = handler
|
|
||||||
|
|
||||||
def __getattr__(self, item):
|
|
||||||
return partial(self._handler, item)
|
|
||||||
|
|
||||||
|
|
||||||
class CallbackInjectionPoint:
|
|
||||||
def __init__(self):
|
|
||||||
self.handlers = {}
|
|
||||||
|
|
||||||
def add(self, func: Callable, order: int):
|
|
||||||
if order not in self.handlers:
|
|
||||||
self.handlers[order] = []
|
|
||||||
self.handlers[order].append(func)
|
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
|
||||||
for order in sorted(self.handlers.keys(), reverse=True):
|
|
||||||
for handler in self.handlers[order]:
|
|
||||||
handler(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class ExtensionsManager:
|
class ExtensionsManager:
|
||||||
def __init__(self, is_canceled: Optional[Callable[[], bool]] = None):
|
def __init__(self, is_canceled: Optional[Callable[[], bool]] = None):
|
||||||
self.extensions: List[ExtensionBase] = []
|
|
||||||
self._is_canceled = is_canceled
|
self._is_canceled = is_canceled
|
||||||
|
|
||||||
self._callbacks: Dict[str, CallbackInjectionPoint] = {}
|
self._extensions: List[ExtensionBase] = []
|
||||||
self.callbacks: ExtCallbacksApi = ProxyCallsClass(self.call_callback)
|
self._ordered_callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}
|
||||||
|
|
||||||
def add_extension(self, ext: ExtensionBase):
|
def add_extension(self, extension: ExtensionBase):
|
||||||
self.extensions.append(ext)
|
self._extensions.append(extension)
|
||||||
|
self._regenerate_ordered_callbacks()
|
||||||
|
|
||||||
self._callbacks.clear()
|
def _regenerate_ordered_callbacks(self):
|
||||||
|
"""Regenerates self._ordered_callbacks. Intended to be called each time a new extension is added."""
|
||||||
|
self._ordered_callbacks = {}
|
||||||
|
|
||||||
for ext in self.extensions:
|
# Fill the ordered callbacks dictionary.
|
||||||
for inj_info in ext.injections:
|
for extension in self._extensions:
|
||||||
if inj_info.type == "callback":
|
for callback_type, callbacks in extension.get_callbacks().items():
|
||||||
if inj_info.name not in self._callbacks:
|
if callback_type not in self._ordered_callbacks:
|
||||||
self._callbacks[inj_info.name] = CallbackInjectionPoint()
|
self._ordered_callbacks[callback_type] = []
|
||||||
self._callbacks[inj_info.name].add(inj_info.function, inj_info.order)
|
self._ordered_callbacks[callback_type].extend(callbacks)
|
||||||
|
|
||||||
else:
|
# Sort each callback list.
|
||||||
raise Exception(f"Unsupported injection type: {inj_info.type}")
|
for callback_type, callbacks in self._ordered_callbacks.items():
|
||||||
|
self._ordered_callbacks[callback_type] = sorted(callbacks, key=lambda x: x.metadata.order)
|
||||||
|
|
||||||
def call_callback(self, name: str, *args, **kwargs):
|
def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext):
|
||||||
# TODO: add to patchers too?
|
# TODO: add to patchers too?
|
||||||
# and if so, should it be only in beginning of function or in for loop
|
# and if so, should it be only in beginning of function or in for loop
|
||||||
if self._is_canceled and self._is_canceled():
|
if self._is_canceled and self._is_canceled():
|
||||||
raise CanceledException
|
raise CanceledException
|
||||||
|
|
||||||
if name in self._callbacks:
|
callbacks = self._ordered_callbacks.get(callback_type, [])
|
||||||
self._callbacks[name](*args, **kwargs)
|
for cb in callbacks:
|
||||||
|
cb.function(ctx)
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def patch_extensions(self, context: DenoiseContext):
|
def patch_extensions(self, context: DenoiseContext):
|
||||||
with ExitStack() as exit_stack:
|
with ExitStack() as exit_stack:
|
||||||
for ext in self.extensions:
|
for ext in self._extensions:
|
||||||
exit_stack.enter_context(ext.patch_extension(context))
|
exit_stack.enter_context(ext.patch_extension(context))
|
||||||
|
|
||||||
yield None
|
yield None
|
||||||
|
Loading…
x
Reference in New Issue
Block a user