from __future__ import annotations from contextlib import ExitStack, contextmanager from typing import TYPE_CHECKING, Callable, Dict, List, Optional import torch from diffusers import UNet2DConditionModel from invokeai.app.services.session_processor.session_processor_common import CanceledException if TYPE_CHECKING: from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType from invokeai.backend.stable_diffusion.extensions.base import CallbackFunctionWithMetadata, ExtensionBase class ExtensionsManager: def __init__(self, is_canceled: Optional[Callable[[], bool]] = None): self._is_canceled = is_canceled # A list of extensions in the order that they were added to the ExtensionsManager. self._extensions: List[ExtensionBase] = [] self._ordered_callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {} def add_extension(self, extension: ExtensionBase): self._extensions.append(extension) self._regenerate_ordered_callbacks() def _regenerate_ordered_callbacks(self): """Regenerates self._ordered_callbacks. Intended to be called each time a new extension is added.""" self._ordered_callbacks = {} # Fill the ordered callbacks dictionary. for extension in self._extensions: for callback_type, callbacks in extension.get_callbacks().items(): if callback_type not in self._ordered_callbacks: self._ordered_callbacks[callback_type] = [] self._ordered_callbacks[callback_type].extend(callbacks) # Sort each callback list. for callback_type, callbacks in self._ordered_callbacks.items(): # Note that sorted() is stable, so if two callbacks have the same order, the order that they extensions were # added will be preserved. self._ordered_callbacks[callback_type] = sorted(callbacks, key=lambda x: x.metadata.order) def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext): if self._is_canceled and self._is_canceled(): raise CanceledException callbacks = self._ordered_callbacks.get(callback_type, []) for cb in callbacks: cb.function(ctx) @contextmanager def patch_extensions(self, ctx: DenoiseContext): if self._is_canceled and self._is_canceled(): raise CanceledException with ExitStack() as exit_stack: for ext in self._extensions: exit_stack.enter_context(ext.patch_extension(ctx)) yield None @contextmanager def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None): if self._is_canceled and self._is_canceled(): raise CanceledException # TODO: create weight patch logic in PR with extension which uses it with ExitStack() as exit_stack: for ext in self._extensions: exit_stack.enter_context(ext.patch_unet(unet, cached_weights)) yield None