diff --git a/invokeai/backend/stable_diffusion/extension_callback_type.py b/invokeai/backend/stable_diffusion/extension_callback_type.py new file mode 100644 index 0000000000..aaefbd7ed0 --- /dev/null +++ b/invokeai/backend/stable_diffusion/extension_callback_type.py @@ -0,0 +1,12 @@ +from enum import Enum + + +class ExtensionCallbackType(Enum): + SETUP = "setup" + PRE_DENOISE_LOOP = "pre_denoise_loop" + POST_DENOISE_LOOP = "post_denoise_loop" + PRE_STEP = "pre_step" + POST_STEP = "post_step" + PRE_UNET = "pre_unet" + POST_UNET = "post_unet" + POST_APPLY_CFG = "post_apply_cfg" diff --git a/invokeai/backend/stable_diffusion/extension_manager_2.py b/invokeai/backend/stable_diffusion/extension_manager_2.py new file mode 100644 index 0000000000..016811ba8d --- /dev/null +++ b/invokeai/backend/stable_diffusion/extension_manager_2.py @@ -0,0 +1,34 @@ +from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext +from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType +from invokeai.backend.stable_diffusion.extensions.extension_base_2 import CallbackFunctionWithMetadata, ExtensionBase + + +class ExtensionManager: + def __init__(self): + self._extensions: list[ExtensionBase] = [] + self._ordered_callbacks: dict[ExtensionCallbackType, list[CallbackFunctionWithMetadata]] = {} + + def add_extension(self, extension: ExtensionBase): + self._extensions.append(extension) + self._regenerate_ordered_callbacks() + + def _regenerate_ordered_callbacks(self): + """Regenerates self._ordered_callbacks. Intended to be called each time a new extension is added.""" + self._ordered_callbacks = {} + + # Fill the ordered callbacks dictionary. + for extension in self._extensions: + for callback_type, callbacks in extension.get_callbacks().items(): + if callback_type not in self._ordered_callbacks: + self._ordered_callbacks[callback_type] = [] + self._ordered_callbacks[callback_type].extend(callbacks) + + # Sort each callback list. + for callback_type, callbacks in self._ordered_callbacks.items(): + self._ordered_callbacks[callback_type] = sorted(callbacks, key=lambda x: x.metadata.order) + + def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext): + cbs = self._ordered_callbacks.get(callback_type, []) + + for cb in cbs: + cb.func(ctx) diff --git a/invokeai/backend/stable_diffusion/extensions/extension_base_2.py b/invokeai/backend/stable_diffusion/extensions/extension_base_2.py new file mode 100644 index 0000000000..cc6202e6d7 --- /dev/null +++ b/invokeai/backend/stable_diffusion/extensions/extension_base_2.py @@ -0,0 +1,60 @@ +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Callable, Dict, List, TypeVar + +import torch +from diffusers import UNet2DConditionModel + +from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext +from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType + + +@dataclass +class CallbackMetadata: + callback_type: ExtensionCallbackType + order: int + + +@dataclass +class CallbackFunctionWithMetadata: + metadata: CallbackMetadata + func: Callable[[DenoiseContext], None] + + +# A TypeVar that represents any subclass of ExtensionBase. +TExtensionBaseSubclass = TypeVar("TExtensionBaseSubclass", bound="ExtensionBase") + + +def callback(callback_type: ExtensionCallbackType, order: int = 0): + """A decorator that marks an extension method as a callback.""" + + def _decorator(func: Callable[[TExtensionBaseSubclass, DenoiseContext], None]): + func._metadata = CallbackMetadata(callback_type, order) # type: ignore + return func + + return _decorator + + +class ExtensionBase: + def __init__(self): + self._callbacks: dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {} + + # Register all of the callback methods for this instance. + for func_name in dir(self): + func = getattr(self, func_name) + metadata = getattr(func, "_metadata", None) + if metadata is not None and isinstance(metadata, CallbackMetadata): + if metadata.callback_type not in self._callbacks: + self._callbacks[metadata.callback_type] = [] + self._callbacks[metadata.callback_type].append(CallbackFunctionWithMetadata(metadata, func)) + + def get_callbacks(self): + return self._callbacks + + @contextmanager + def patch_attention_processor(self, attention_processor_cls: object): + yield None + + @contextmanager + def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel): + yield None diff --git a/tests/backend/stable_diffusion/extensions/test_extension_base_2.py b/tests/backend/stable_diffusion/extensions/test_extension_base_2.py new file mode 100644 index 0000000000..1e95074e98 --- /dev/null +++ b/tests/backend/stable_diffusion/extensions/test_extension_base_2.py @@ -0,0 +1,44 @@ +from unittest import mock + +from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext +from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType +from invokeai.backend.stable_diffusion.extensions.extension_base_2 import ExtensionBase, callback + + +class MockExtension(ExtensionBase): + """A mock ExtensionBase subclass for testing purposes.""" + + def __init__(self, x: int): + super().__init__() + self._x = x + + @callback(ExtensionCallbackType.PRE_DENOISE_LOOP) + def set_step_index(self, ctx: DenoiseContext): + ctx.step_index = self._x + + +def test_extension_base_callback_registration(): + val = 5 + mock_extension = MockExtension(val) + + mock_ctx = mock.MagicMock() + + callbacks = mock_extension.get_callbacks() + pre_denoise_loop_cbs = callbacks.get(ExtensionCallbackType.PRE_DENOISE_LOOP, []) + assert len(pre_denoise_loop_cbs) == 1 + + # Call the mock callback. + pre_denoise_loop_cbs[0].func(mock_ctx) + + # Confirm that the callback ran. + assert mock_ctx.step_index == val + + +def test_extension_base_empty_callback_type(): + mock_extension = MockExtension(5) + + # There should be no callbacks registered for POST_DENOISE_LOOP. + callbacks = mock_extension.get_callbacks() + + post_denoise_loop_cbs = callbacks.get(ExtensionCallbackType.POST_DENOISE_LOOP, []) + assert len(post_denoise_loop_cbs) == 0 diff --git a/tests/backend/stable_diffusion/test_extension_manager_2.py b/tests/backend/stable_diffusion/test_extension_manager_2.py new file mode 100644 index 0000000000..96c3e3e53b --- /dev/null +++ b/tests/backend/stable_diffusion/test_extension_manager_2.py @@ -0,0 +1,73 @@ +from unittest import mock + +import pytest + +from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext +from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType +from invokeai.backend.stable_diffusion.extension_manager_2 import ExtensionManager +from invokeai.backend.stable_diffusion.extensions.extension_base_2 import ( + ExtensionBase, + callback, +) + + +class MockExtension(ExtensionBase): + """A mock ExtensionBase subclass for testing purposes.""" + + def __init__(self, x: int): + super().__init__() + self._x = x + + # Note that order is not specified. It should default to 0. + @callback(ExtensionCallbackType.PRE_DENOISE_LOOP) + def set_step_index(self, ctx: DenoiseContext): + ctx.step_index = self._x + + +class MockExtensionLate(ExtensionBase): + """A mock ExtensionBase subclass with a high order value on its PRE_DENOISE_LOOP callback.""" + + def __init__(self, x: int): + super().__init__() + self._x = x + + @callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000) + def set_step_index(self, ctx: DenoiseContext): + ctx.step_index = self._x + + +def test_extension_manager_run_callback(): + """Test that run_callback runs all callbacks for the given callback type.""" + + em = ExtensionManager() + mock_extension_1 = MockExtension(1) + em.add_extension(mock_extension_1) + + mock_ctx = mock.MagicMock() + em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx) + + assert mock_ctx.step_index == 1 + + +def test_extension_manager_run_callback_no_callbacks(): + """Test that run_callback does not raise an error when there are no callbacks for the given callback type.""" + em = ExtensionManager() + mock_ctx = mock.MagicMock() + em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx) + + +@pytest.mark.parametrize( + ["extension_1", "extension_2"], + # Regardless of initialization order, we expect MockExtensionLate to run last. + [(MockExtension(1), MockExtensionLate(2)), (MockExtensionLate(2), MockExtension(1))], +) +def test_extension_manager_order_callbacks(extension_1: ExtensionBase, extension_2: ExtensionBase): + """Test that run_callback runs callbacks in the correct order.""" + em = ExtensionManager() + em.add_extension(extension_1) + em.add_extension(extension_2) + + mock_ctx = mock.MagicMock() + em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx) + + assert mock_ctx.step_index == 2