2024-07-12 17:31:26 +00:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
from contextlib import ExitStack, contextmanager
|
2024-07-29 21:34:37 +00:00
|
|
|
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
|
2024-07-12 17:31:26 +00:00
|
|
|
|
|
|
|
import torch
|
|
|
|
from diffusers import UNet2DConditionModel
|
|
|
|
|
2024-07-17 01:39:15 +00:00
|
|
|
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
|
|
|
|
2024-07-12 17:31:26 +00:00
|
|
|
if TYPE_CHECKING:
|
|
|
|
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
2024-07-18 20:49:44 +00:00
|
|
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
|
|
|
from invokeai.backend.stable_diffusion.extensions.base import CallbackFunctionWithMetadata, ExtensionBase
|
2024-07-12 17:31:26 +00:00
|
|
|
|
|
|
|
|
|
|
|
class ExtensionsManager:
    """Keeps the registered extensions and dispatches their callbacks and patch context managers.

    Extensions are stored in the order in which they were added. Their callbacks are merged into one list per
    callback type and sorted by `metadata.order`.
    """

    def __init__(self, is_canceled: Optional[Callable[[], bool]] = None):
        """Create a manager with no extensions.

        Args:
            is_canceled: Optional zero-argument callable. It is polled at the start of `run_callback()` and when
                the `patch_extensions()` / `patch_unet()` contexts are entered; a truthy result makes them raise
                `CanceledException`.
        """
        self._is_canceled = is_canceled

        # A list of extensions in the order that they were added to the ExtensionsManager.
        self._extensions: List[ExtensionBase] = []
        # For each callback type, the callbacks of all extensions sorted by `metadata.order`.
        self._ordered_callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}

    def add_extension(self, extension: ExtensionBase):
        """Register `extension` and rebuild the ordered callback table so that it includes its callbacks."""
        self._extensions.append(extension)
        self._regenerate_ordered_callbacks()

    def _raise_if_canceled(self):
        """Raise `CanceledException` if a cancellation check was provided and it reports cancellation."""
        if self._is_canceled and self._is_canceled():
            raise CanceledException

    def _regenerate_ordered_callbacks(self):
        """Regenerates self._ordered_callbacks. Intended to be called each time a new extension is added."""
        # Build into a local dict and publish it at the end, so that a failing get_callbacks() call does not leave
        # self._ordered_callbacks half-filled.
        ordered: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}

        # Fill the ordered callbacks dictionary.
        for extension in self._extensions:
            for callback_type, callbacks in extension.get_callbacks().items():
                ordered.setdefault(callback_type, []).extend(callbacks)

        # Sort each callback list.
        for callbacks in ordered.values():
            # Note that list.sort() is stable, so if two callbacks have the same order, the order in which their
            # extensions were added will be preserved.
            callbacks.sort(key=lambda cb: cb.metadata.order)

        self._ordered_callbacks = ordered

    def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext):
        """Call every callback registered for `callback_type` with `ctx`, in sorted order.

        Does nothing if no extension registered a callback for `callback_type`.

        Raises:
            CanceledException: If the cancellation check reports cancellation. The check runs once, before the
                first callback is called.
        """
        self._raise_if_canceled()

        for cb in self._ordered_callbacks.get(callback_type, []):
            cb.function(ctx)

    @contextmanager
    def patch_extensions(self, ctx: DenoiseContext):
        """Context manager that enters `patch_extension(ctx)` of every extension, in the order they were added.

        The entered contexts are closed by an `ExitStack` when the `with` block ends, i.e. in reverse order.

        Raises:
            CanceledException: If the cancellation check reports cancellation when the context is entered.
        """
        self._raise_if_canceled()

        with ExitStack() as exit_stack:
            for ext in self._extensions:
                exit_stack.enter_context(ext.patch_extension(ctx))

            yield None

    @contextmanager
    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
        """Context manager that enters `patch_unet(...)` of every extension and restores weights on exit.

        Args:
            unet: The model handed to each extension's `patch_unet()`.
            cached_weights: Optional mapping of parameter name to tensor. Its entries are copied into the dict
                of weights that is written back to `unet` when the context exits.

        Raises:
            CanceledException: If the cancellation check reports cancellation when the context is entered.
        """
        self._raise_if_canceled()

        # Copy instead of aliasing: this dict is handed to the extensions (presumably so that they can record
        # the weights they overwrite - confirm against ExtensionBase.patch_unet), and the caller's dict must not
        # be mutated by that.
        original_weights: Dict[str, torch.Tensor] = {}
        if cached_weights:
            original_weights.update(cached_weights)

        try:
            with ExitStack() as exit_stack:
                for ext in self._extensions:
                    exit_stack.enter_context(ext.patch_unet(unet, original_weights))

                yield None

        finally:
            # Runs on normal exit and on error: copy every recorded tensor back into the matching parameter.
            with torch.no_grad():
                for param_key, weight in original_weights.items():
                    unet.get_parameter(param_key).copy_(weight)
|