from __future__ import annotations

from contextlib import ExitStack, contextmanager
from typing import TYPE_CHECKING, Callable, Dict, List, Optional

import torch
from diffusers import UNet2DConditionModel

from invokeai.app.services.session_processor.session_processor_common import CanceledException

if TYPE_CHECKING:
    from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
    from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
    from invokeai.backend.stable_diffusion.extensions.base import CallbackFunctionWithMetadata, ExtensionBase


class ExtensionsManager:
    """Registry of extensions that dispatches their callbacks and applies their patches.

    Extensions are kept in registration order. Their callbacks are grouped by
    callback type and pre-sorted, so ``run_callback`` only walks an
    already-ordered list.
    """

    def __init__(self, is_canceled: Optional[Callable[[], bool]] = None):
        """Create a manager with no extensions registered.

        Args:
            is_canceled: Optional zero-argument predicate polled by ``run_callback``.
                When it returns a truthy value, ``CanceledException`` is raised.
        """
        # Registered extensions, in the order they were added.
        self._extensions: List[ExtensionBase] = []
        # Cache: callback type -> callbacks sorted by ``metadata.order``.
        # Rebuilt from scratch by ``_regenerate_ordered_callbacks`` on every ``add_extension``.
        self._ordered_callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}
        self._is_canceled = is_canceled

    def add_extension(self, extension: ExtensionBase):
        """Register ``extension`` and refresh the ordered-callback cache."""
        self._extensions.append(extension)
        self._regenerate_ordered_callbacks()

    def _regenerate_ordered_callbacks(self):
        """Rebuild ``self._ordered_callbacks`` from every registered extension.

        Callbacks of each type are gathered in registration order. They are
        then sorted by ``metadata.order``. Python's sort is stable, so callbacks
        with equal ``order`` keep their registration order.
        """
        grouped = {}
        self._ordered_callbacks = grouped

        # Group each extension's callbacks under their callback type.
        for ext in self._extensions:
            for cb_type, cb_list in ext.get_callbacks().items():
                grouped.setdefault(cb_type, []).extend(cb_list)

        # Order every group by its declared priority.
        for bucket in grouped.values():
            bucket.sort(key=lambda entry: entry.metadata.order)

    def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext):
        """Invoke every callback registered for ``callback_type``, passing ``ctx``.

        Callbacks run in ascending ``metadata.order``. A callback type with no
        registered callbacks is a no-op.

        Raises:
            CanceledException: If the cancellation predicate reports a cancel.
                This is checked once, before any callback runs.
        """
        # TODO: should the patchers poll for cancellation too? If so, should the
        # check happen only at the start of the function or inside its loop?
        cancel_check = self._is_canceled
        if cancel_check and cancel_check():
            raise CanceledException

        for entry in self._ordered_callbacks.get(callback_type, ()):
            entry.function(ctx)

    @contextmanager
    def patch_extensions(self, context: DenoiseContext):
        """Keep each extension's ``patch_extension(context)`` active for the ``with`` body.

        Patches are entered in registration order. ``ExitStack`` unwinds them in
        reverse order on exit. This also happens when the body raises or when
        entering a later patch fails.
        """
        with ExitStack() as stack:
            for ext in self._extensions:
                stack.enter_context(ext.patch_extension(context))
            yield

    @contextmanager
    def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
        """Placeholder for UNet patching: currently a no-op that yields ``None``.

        Both ``state_dict`` and ``unet`` are ignored for now.
        """
        # TODO: create logic in PR with extension which uses it
        yield