from unittest import mock

import pytest

from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager


class MockExtension(ExtensionBase):
    """Test extension whose PRE_DENOISE_LOOP callback writes a fixed value into `ctx.step_index`."""

    def __init__(self, x: int):
        super().__init__()
        self._x = x

    # No explicit `order` is passed to the decorator here; the expectation is that it defaults to 0.
    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
    def set_step_index(self, ctx: DenoiseContext):
        ctx.step_index = self._x


class MockExtensionLate(ExtensionBase):
    """Same as MockExtension, except its PRE_DENOISE_LOOP callback is registered with a large order (1000)."""

    def __init__(self, x: int):
        super().__init__()
        self._x = x

    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000)
    def set_step_index(self, ctx: DenoiseContext):
        ctx.step_index = self._x


def test_extension_manager_run_callback():
    """A registered PRE_DENOISE_LOOP callback is invoked by run_callback and mutates the context."""
    manager = ExtensionsManager()
    manager.add_extension(MockExtension(1))

    ctx = mock.MagicMock()
    manager.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, ctx)

    assert ctx.step_index == 1


def test_extension_manager_run_callback_no_callbacks():
    """run_callback must not raise when no callbacks are registered for the requested type."""
    manager = ExtensionsManager()
    ctx = mock.MagicMock()

    manager.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, ctx)


@pytest.mark.parametrize(
    ("extension_1", "extension_2"),
    # Both registration orders are covered: MockExtensionLate is expected to run last either way.
    [
        (MockExtension(1), MockExtensionLate(2)),
        (MockExtensionLate(2), MockExtension(1)),
    ],
)
def test_extension_manager_order_callbacks(extension_1: ExtensionBase, extension_2: ExtensionBase):
    """Callbacks run according to their `order` value, not the order in which extensions were added."""
    manager = ExtensionsManager()
    for extension in (extension_1, extension_2):
        manager.add_extension(extension)

    ctx = mock.MagicMock()
    manager.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, ctx)

    # The last writer wins, so the value set by MockExtensionLate (2) must be the one left behind.
    assert ctx.step_index == 2


class MockExtensionStableSort(ExtensionBase):
    """Test extension exposing three PRE_DENOISE_LOOP callbacks with distinct order values."""

    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=-1000)
    def early(self, ctx: DenoiseContext):
        """No-op callback registered with order=-1000."""

    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
    def middle(self, ctx: DenoiseContext):
        """No-op callback registered without an explicit order."""

    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000)
    def late(self, ctx: DenoiseContext):
        """No-op callback registered with order=1000."""


def test_extension_manager_stable_sort():
    """Callbacks sharing an `order` value keep the order in which their extensions were added."""
    manager = ExtensionsManager()
    first = MockExtensionStableSort()
    second = MockExtensionStableSort()
    manager.add_extension(first)
    manager.add_extension(second)

    # Grouped by callback (early, middle, late); within each group, by registration order.
    expected_order = [
        getattr(extension, callback_name)
        for callback_name in ("early", "middle", "late")
        for extension in (first, second)
    ]

    # Reaching into a private attribute is not ideal, but it is the most direct way to check the
    # resulting callback ordering.
    registered = manager._ordered_callbacks[ExtensionCallbackType.PRE_DENOISE_LOOP]
    actual_order = [cb.function for cb in registered]

    assert actual_order == expected_order