InvokeAI/invokeai/backend/stable_diffusion/extensions_manager.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

98 lines
4.1 KiB
Python
Raw Normal View History

2024-07-12 17:31:26 +00:00
from __future__ import annotations
from contextlib import ExitStack, contextmanager
2024-07-24 02:07:29 +00:00
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set
2024-07-12 17:31:26 +00:00
import torch
from diffusers import UNet2DConditionModel
from invokeai.app.services.session_processor.session_processor_common import CanceledException
2024-07-12 17:31:26 +00:00
if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import CallbackFunctionWithMetadata, ExtensionBase
2024-07-12 17:31:26 +00:00
class ExtensionsManager:
    """Holds denoising extensions and drives their callbacks and context-manager patches.

    Extensions are kept in the order they were added. Their callbacks are grouped by callback type and each
    group is sorted by ``metadata.order``; callbacks with equal order keep the order in which their
    extensions were added.
    """

    def __init__(self, is_canceled: Optional[Callable[[], bool]] = None):
        """
        Args:
            is_canceled: Optional zero-argument predicate. When given and it returns True, `run_callback`,
                `patch_extensions` and `patch_unet` raise `CanceledException` before doing any work.
        """
        self._is_canceled = is_canceled

        # A list of extensions in the order that they were added to the ExtensionsManager.
        self._extensions: List[ExtensionBase] = []
        # Callbacks grouped by type; each list is sorted by `metadata.order`.
        self._ordered_callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}

    def _check_canceled(self) -> None:
        """Raise `CanceledException` if a cancellation predicate was supplied and it reports cancellation."""
        if self._is_canceled and self._is_canceled():
            raise CanceledException

    def add_extension(self, extension: ExtensionBase):
        """Register `extension` (appended after existing ones) and rebuild the ordered callback table."""
        self._extensions.append(extension)
        self._regenerate_ordered_callbacks()

    def _regenerate_ordered_callbacks(self):
        """Regenerates self._ordered_callbacks. Intended to be called each time a new extension is added."""
        # Group every extension's callbacks by type, preserving extension insertion order.
        grouped: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}
        for extension in self._extensions:
            for callback_type, callbacks in extension.get_callbacks().items():
                grouped.setdefault(callback_type, []).extend(callbacks)

        # sorted() is stable, so if two callbacks have the same order, the order in which their extensions were
        # added is preserved. The table is swapped in only once it is fully built.
        self._ordered_callbacks = {
            callback_type: sorted(callbacks, key=lambda cb: cb.metadata.order)
            for callback_type, callbacks in grouped.items()
        }

    def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext):
        """Call every callback registered for `callback_type`, in sorted order, passing `ctx`.

        Does nothing if no callback is registered for `callback_type`.

        Raises:
            CanceledException: if cancellation was requested before any callback ran.
        """
        self._check_canceled()
        for cb in self._ordered_callbacks.get(callback_type, []):
            cb.function(ctx)

    @contextmanager
    def patch_extensions(self, ctx: DenoiseContext):
        """Enter each extension's `patch_extension(ctx)` context, in insertion order, for the duration of the block.

        The contexts are exited in reverse order when the block ends.

        Raises:
            CanceledException: if cancellation was requested before any context was entered.
        """
        self._check_canceled()

        with ExitStack() as exit_stack:
            for ext in self._extensions:
                exit_stack.enter_context(ext.patch_extension(ctx))

            yield None

    @contextmanager
    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
        """Apply each extension's `patch_unet(unet, cached_weights)` context, then restore the touched parameters.

        Each extension context may yield `None` (nothing to restore) or a pair
        `(modified_cached_weights, modified_weights)`:
          - `modified_cached_weights`: parameter keys whose values are restored from `cached_weights`;
          - `modified_weights`: parameter key -> tensor to copy back into that parameter on exit.

        On exit, the extension contexts are closed first. Then the listed parameters are copied back in place
        under `torch.no_grad()`. The copy-back runs even if closing an extension context raises.

        Args:
            unet: Model whose parameters are patched in place.
            cached_weights: Optional parameter key -> tensor mapping used to restore `modified_cached_weights`.
                NOTE(review): presumably an extension only reports cached keys when this is provided; a
                None here with non-empty cached keys would raise TypeError on restore.

        Raises:
            CanceledException: if cancellation was requested before any patch was applied.
        """
        self._check_canceled()

        modified_weights: Dict[str, torch.Tensor] = {}
        modified_cached_weights: Set[str] = set()
        exit_stack = ExitStack()
        try:
            for ext in self._extensions:
                res = exit_stack.enter_context(ext.patch_unet(unet, cached_weights))
                if res is None:
                    continue
                ext_modified_cached_weights, ext_modified_weights = res
                modified_cached_weights.update(ext_modified_cached_weights)
                # Store only the first returned weight for each key: a later extension that changes the same
                # parameter saw (and saved) an already-modified value.
                for param_key, weight in ext_modified_weights.items():
                    modified_weights.setdefault(param_key, weight)

            yield None

        finally:
            try:
                # close() unwinds the extension contexts in reverse order without passing them any in-flight
                # exception.
                exit_stack.close()
            finally:
                # Restore even if an extension's exit raised, so the UNet is never left patched.
                with torch.no_grad():
                    for param_key in modified_cached_weights:
                        unet.get_parameter(param_key).copy_(cached_weights[param_key])
                    for param_key, weight in modified_weights.items():
                        unet.get_parameter(param_key).copy_(weight)