# InvokeAI/invokeai/backend/stable_diffusion/extensions/base.py

from __future__ import annotations

import functools
import inspect
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Dict, List

from diffusers import UNet2DConditionModel

if TYPE_CHECKING:
    from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
    from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
    from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


@dataclass
class CallbackMetadata:
    """Registration data that the ``callback`` decorator attaches to a function as ``_ext_metadata``."""

    # The callback type the decorated function is registered under.
    callback_type: ExtensionCallbackType
    # Ordering value passed to ``callback`` (0 by default); stored here as-is.
    order: int


@dataclass
class CallbackFunctionWithMetadata:
    """A registered callback: its ``CallbackMetadata`` paired with the callable to invoke."""

    # Metadata taken from the function's ``_ext_metadata`` attribute.
    metadata: CallbackMetadata
    # The callable itself (a bound method when registered by ``ExtensionBase``); it takes a DenoiseContext.
    function: Callable[[DenoiseContext], None]


def callback(callback_type: ExtensionCallbackType, order: int = 0):
    """Build a decorator that marks a function as an extension callback.

    The decorator returns the very same function object, after setting its ``_ext_metadata`` attribute to a
    ``CallbackMetadata`` built from ``callback_type`` and ``order``. ``ExtensionBase.__init__`` looks for that
    attribute when registering an instance's callbacks.

    Args:
        callback_type (ExtensionCallbackType): The callback type to register the function under.
        order (int): Ordering value stored in the metadata. Defaults to 0.

    Returns:
        A decorator that tags and returns its argument unchanged.
    """

    def _attach_metadata(function):
        metadata = CallbackMetadata(callback_type=callback_type, order=order)
        function._ext_metadata = metadata
        return function

    return _attach_metadata


class ExtensionBase:
    """Base class for denoising extensions.

    On construction, every attribute of the instance that carries a ``CallbackMetadata`` under ``_ext_metadata``
    (as set by the ``callback`` decorator) is registered, grouped by its callback type. Subclasses can also
    override ``patch_extension`` and ``patch_unet`` to apply patches for the lifetime of those context managers.
    """

    def __init__(self) -> None:
        self._callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}

        # Register all of the callback methods for this instance.
        for func_name in dir(self):
            # Look the attribute up statically first. A plain `getattr` on a property (or
            # `functools.cached_property`) would evaluate it here, during construction. A subclass property may
            # depend on attributes that are only assigned after `super().__init__()` returns. Properties are
            # therefore not considered as callbacks.
            static_attr = inspect.getattr_static(self, func_name, None)
            if isinstance(static_attr, (property, functools.cached_property)):
                continue

            # Resolve normally so that methods come back bound to this instance.
            func = getattr(self, func_name)
            metadata = getattr(func, "_ext_metadata", None)
            if isinstance(metadata, CallbackMetadata):
                self._callbacks.setdefault(metadata.callback_type, []).append(
                    CallbackFunctionWithMetadata(metadata, func)
                )

    def get_callbacks(self) -> Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]]:
        """Return this instance's registered callbacks, keyed by callback type.

        The returned dict is the internal registry itself, not a copy. Each list is in ``dir()`` order, i.e.
        alphabetical by attribute name. It is NOT sorted by ``CallbackMetadata.order``.
        """
        return self._callbacks

    @contextmanager
    def patch_extension(self, ctx: DenoiseContext):
        """A context manager hook for applying extension-level patches while the context is open.

        The base implementation changes nothing and yields ``None``. Subclasses that override it are expected
        to undo any patches they apply when the context exits.

        Args:
            ctx (DenoiseContext): The denoise context passed in by the caller.
        """
        yield None

    @contextmanager
    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
        """A context manager for applying patches to the UNet model.

        The context manager's lifetime spans the entire diffusion process. Weight unpatching is handled
        upstream. It relies on each extension saving the unmodified weights it changes with
        `original_weights.save`, which lets upstream code avoid redundant operations. All other patches
        (e.g. changes to tensor shapes, function monkey-patches, etc.) must be undone by this context manager
        itself when it exits.

        The base implementation applies no patches.

        Args:
            unet (UNet2DConditionModel): The UNet model, on the execution device, to patch.
            original_weights (OriginalWeightsStorage): Storage holding a CPU copy of the model's original
                weights, used for unpatching. Extensions should save every tensor they modify to this storage.
                They can also read original weight values from it.
        """
        yield