from __future__ import annotations

from contextlib import ExitStack, contextmanager
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set

import torch
from diffusers import UNet2DConditionModel

from invokeai.app.services.session_processor.session_processor_common import CanceledException

if TYPE_CHECKING:
    from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
    from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
    from invokeai.backend.stable_diffusion.extensions.base import CallbackFunctionWithMetadata, ExtensionBase


class ExtensionsManager:
    """Holds the registered extensions and dispatches their callbacks and patches.

    Extensions are kept in registration order. Their callbacks are grouped by callback type and
    sorted by ``metadata.order``. Each public entry point first consults the optional
    ``is_canceled`` callable and raises ``CanceledException`` when it returns a truthy value.
    """

    def __init__(self, is_canceled: Optional[Callable[[], bool]] = None):
        self._is_canceled = is_canceled

        # Extensions, in the order in which they were added to this manager.
        self._extensions: List[ExtensionBase] = []
        # Per callback type: the callbacks of all extensions, sorted by their `metadata.order`.
        self._ordered_callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}

    def _raise_if_canceled(self) -> None:
        """Raise CanceledException if a cancellation check was supplied and it reports cancellation."""
        if self._is_canceled and self._is_canceled():
            raise CanceledException

    def add_extension(self, extension: ExtensionBase) -> None:
        """Register `extension` and rebuild the ordered callback table."""
        self._extensions.append(extension)
        self._regenerate_ordered_callbacks()

    def _regenerate_ordered_callbacks(self) -> None:
        """Rebuild self._ordered_callbacks from scratch. Meant to run whenever an extension is added."""
        self._ordered_callbacks = {}
        callbacks_by_type = self._ordered_callbacks

        # Collect the callbacks of every extension, grouped by callback type.
        for extension in self._extensions:
            for callback_type, callbacks in extension.get_callbacks().items():
                callbacks_by_type.setdefault(callback_type, []).extend(callbacks)

        # Order each group. list.sort() is stable, so callbacks that share the same order value stay in the
        # order in which their extensions were added.
        for callback_list in callbacks_by_type.values():
            callback_list.sort(key=lambda cb: cb.metadata.order)

    def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext) -> None:
        """Invoke every callback registered for `callback_type`, in sorted order, passing `ctx` to each.

        Raises:
            CanceledException: if the cancellation check reports that processing was canceled.
        """
        self._raise_if_canceled()

        # A callback type that no extension registered simply has nothing to run.
        for callback in self._ordered_callbacks.get(callback_type, []):
            callback.function(ctx)

    @contextmanager
    def patch_extensions(self, ctx: DenoiseContext):
        """Context manager that enters `patch_extension(ctx)` of every extension, in registration order.

        The ExitStack unwinds the entered contexts in reverse order when the `with` body finishes.

        Raises:
            CanceledException: if the cancellation check reports cancellation on entry.
        """
        self._raise_if_canceled()

        with ExitStack() as stack:
            for extension in self._extensions:
                stack.enter_context(extension.patch_extension(ctx))

            yield None

    @contextmanager
    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
        """Context manager that lets every extension patch `unet`, then restores its parameters on exit.

        Each extension's `patch_unet(unet, cached_weights)` context yields a pair:
        an iterable of parameter keys to restore from `cached_weights`, and a mapping of parameter key to
        a tensor to copy back into the UNet on exit.

        Args:
            unet: The model whose parameters are patched and later restored.
            cached_weights: Optional mapping of parameter key to tensor. It is indexed on exit for every
                key an extension reported in the first element of its pair, so it must contain those keys.

        Raises:
            CanceledException: if the cancellation check reports cancellation on entry.
        """
        self._raise_if_canceled()

        weights_to_restore: Dict[str, torch.Tensor] = {}
        cached_keys_to_restore: Set[str] = set()

        try:
            with ExitStack() as stack:
                for extension in self._extensions:
                    unet_patch = extension.patch_unet(unet, cached_weights)
                    ext_cached_keys, ext_weights = stack.enter_context(unet_patch)

                    cached_keys_to_restore.update(ext_cached_keys)

                    # Keep only the first tensor reported for each key: any later extension that touches
                    # the same parameter works on the already modified weight.
                    for param_key, weight in ext_weights.items():
                        weights_to_restore.setdefault(param_key, weight)

                yield None

        finally:
            # Restoration runs after the extension contexts have exited, whether or not an error occurred.
            with torch.no_grad():
                for param_key in cached_keys_to_restore:
                    unet.get_parameter(param_key).copy_(cached_weights[param_key])
                for param_key, weight in weights_to_restore.items():
                    unet.get_parameter(param_key).copy_(weight)