diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py index 1552fb5dd7..f42a065e82 100644 --- a/invokeai/backend/stable_diffusion/extensions_manager.py +++ b/invokeai/backend/stable_diffusion/extensions_manager.py @@ -18,6 +18,7 @@ class ExtensionsManager: def __init__(self, is_canceled: Optional[Callable[[], bool]] = None): self._is_canceled = is_canceled + # A list of extensions in the order that they were added to the ExtensionsManager. self._extensions: List[ExtensionBase] = [] self._ordered_callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {} @@ -38,6 +39,8 @@ class ExtensionsManager: # Sort each callback list. for callback_type, callbacks in self._ordered_callbacks.items(): + # Note that sorted() is stable, so if two callbacks have the same order, the order that they extensions were + # added will be preserved. self._ordered_callbacks[callback_type] = sorted(callbacks, key=lambda x: x.metadata.order) def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext): diff --git a/tests/backend/stable_diffusion/extensions/test_base.py b/tests/backend/stable_diffusion/extensions/test_base.py new file mode 100644 index 0000000000..d024c551a2 --- /dev/null +++ b/tests/backend/stable_diffusion/extensions/test_base.py @@ -0,0 +1,46 @@ +from unittest import mock + +from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext +from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType +from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback + + +class MockExtension(ExtensionBase): + """A mock ExtensionBase subclass for testing purposes.""" + + def __init__(self, x: int): + super().__init__() + self._x = x + + @callback(ExtensionCallbackType.PRE_DENOISE_LOOP) + def set_step_index(self, ctx: DenoiseContext): + ctx.step_index = self._x + + +def test_extension_base_callback_registration(): + """Test that a callback can be successfully registered with an extension.""" + val = 5 + mock_extension = MockExtension(val) + + mock_ctx = mock.MagicMock() + + callbacks = mock_extension.get_callbacks() + pre_denoise_loop_cbs = callbacks.get(ExtensionCallbackType.PRE_DENOISE_LOOP, []) + assert len(pre_denoise_loop_cbs) == 1 + + # Call the mock callback. + pre_denoise_loop_cbs[0].function(mock_ctx) + + # Confirm that the callback ran. + assert mock_ctx.step_index == val + + +def test_extension_base_empty_callback_type(): + """Test that an empty list is returned when no callbacks are registered for a given callback type.""" + mock_extension = MockExtension(5) + + # There should be no callbacks registered for POST_DENOISE_LOOP. + callbacks = mock_extension.get_callbacks() + + post_denoise_loop_cbs = callbacks.get(ExtensionCallbackType.POST_DENOISE_LOOP, []) + assert len(post_denoise_loop_cbs) == 0 diff --git a/tests/backend/stable_diffusion/test_extension_manager.py b/tests/backend/stable_diffusion/test_extension_manager.py new file mode 100644 index 0000000000..889f8316e5 --- /dev/null +++ b/tests/backend/stable_diffusion/test_extension_manager.py @@ -0,0 +1,112 @@ +from unittest import mock + +import pytest + +from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext +from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType +from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback +from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager + + +class MockExtension(ExtensionBase): + """A mock ExtensionBase subclass for testing purposes.""" + + def __init__(self, x: int): + super().__init__() + self._x = x + + # Note that order is not specified. It should default to 0. + @callback(ExtensionCallbackType.PRE_DENOISE_LOOP) + def set_step_index(self, ctx: DenoiseContext): + ctx.step_index = self._x + + +class MockExtensionLate(ExtensionBase): + """A mock ExtensionBase subclass with a high order value on its PRE_DENOISE_LOOP callback.""" + + def __init__(self, x: int): + super().__init__() + self._x = x + + @callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000) + def set_step_index(self, ctx: DenoiseContext): + ctx.step_index = self._x + + +def test_extension_manager_run_callback(): + """Test that run_callback runs all callbacks for the given callback type.""" + + em = ExtensionsManager() + mock_extension_1 = MockExtension(1) + em.add_extension(mock_extension_1) + + mock_ctx = mock.MagicMock() + em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx) + + assert mock_ctx.step_index == 1 + + +def test_extension_manager_run_callback_no_callbacks(): + """Test that run_callback does not raise an error when there are no callbacks for the given callback type.""" + em = ExtensionsManager() + mock_ctx = mock.MagicMock() + em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx) + + +@pytest.mark.parametrize( + ["extension_1", "extension_2"], + # Regardless of initialization order, we expect MockExtensionLate to run last. + [(MockExtension(1), MockExtensionLate(2)), (MockExtensionLate(2), MockExtension(1))], +) +def test_extension_manager_order_callbacks(extension_1: ExtensionBase, extension_2: ExtensionBase): + """Test that run_callback runs callbacks in the correct order.""" + em = ExtensionsManager() + em.add_extension(extension_1) + em.add_extension(extension_2) + + mock_ctx = mock.MagicMock() + em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx) + + assert mock_ctx.step_index == 2 + + +class MockExtensionStableSort(ExtensionBase): + """A mock extension with three PRE_DENOISE_LOOP callbacks, each with a different order value.""" + + @callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=-1000) + def early(self, ctx: DenoiseContext): + pass + + @callback(ExtensionCallbackType.PRE_DENOISE_LOOP) + def middle(self, ctx: DenoiseContext): + pass + + @callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000) + def late(self, ctx: DenoiseContext): + pass + + +def test_extension_manager_stable_sort(): + """Test that when two callbacks have the same 'order' value, they are sorted based on the order they were added to + the ExtensionsManager.""" + + em = ExtensionsManager() + + mock_extension_1 = MockExtensionStableSort() + mock_extension_2 = MockExtensionStableSort() + + em.add_extension(mock_extension_1) + em.add_extension(mock_extension_2) + + expected_order = [ + mock_extension_1.early, + mock_extension_2.early, + mock_extension_1.middle, + mock_extension_2.middle, + mock_extension_1.late, + mock_extension_2.late, + ] + + # It's not ideal that we are accessing a private attribute here, but this was the most direct way to assert the + # desired behaviour. + assert [cb.function for cb in em._ordered_callbacks[ExtensionCallbackType.PRE_DENOISE_LOOP]] == expected_order