mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add unit tests for ExtensionsManager and ExtensionBase.
This commit is contained in:
parent
0c56d4a581
commit
83a86abce2
@ -18,6 +18,7 @@ class ExtensionsManager:
|
|||||||
def __init__(self, is_canceled: Optional[Callable[[], bool]] = None):
|
def __init__(self, is_canceled: Optional[Callable[[], bool]] = None):
|
||||||
self._is_canceled = is_canceled
|
self._is_canceled = is_canceled
|
||||||
|
|
||||||
|
# A list of extensions in the order that they were added to the ExtensionsManager.
|
||||||
self._extensions: List[ExtensionBase] = []
|
self._extensions: List[ExtensionBase] = []
|
||||||
self._ordered_callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}
|
self._ordered_callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}
|
||||||
|
|
||||||
@ -38,6 +39,8 @@ class ExtensionsManager:
|
|||||||
|
|
||||||
# Sort each callback list.
|
# Sort each callback list.
|
||||||
for callback_type, callbacks in self._ordered_callbacks.items():
|
for callback_type, callbacks in self._ordered_callbacks.items():
|
||||||
|
# Note that sorted() is stable, so if two callbacks have the same order, the order that they extensions were
|
||||||
|
# added will be preserved.
|
||||||
self._ordered_callbacks[callback_type] = sorted(callbacks, key=lambda x: x.metadata.order)
|
self._ordered_callbacks[callback_type] = sorted(callbacks, key=lambda x: x.metadata.order)
|
||||||
|
|
||||||
def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext):
|
def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext):
|
||||||
|
46
tests/backend/stable_diffusion/extensions/test_base.py
Normal file
46
tests/backend/stable_diffusion/extensions/test_base.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||||
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||||
|
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
|
||||||
|
|
||||||
|
|
||||||
|
class MockExtension(ExtensionBase):
|
||||||
|
"""A mock ExtensionBase subclass for testing purposes."""
|
||||||
|
|
||||||
|
def __init__(self, x: int):
|
||||||
|
super().__init__()
|
||||||
|
self._x = x
|
||||||
|
|
||||||
|
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
|
||||||
|
def set_step_index(self, ctx: DenoiseContext):
|
||||||
|
ctx.step_index = self._x
|
||||||
|
|
||||||
|
|
||||||
|
def test_extension_base_callback_registration():
|
||||||
|
"""Test that a callback can be successfully registered with an extension."""
|
||||||
|
val = 5
|
||||||
|
mock_extension = MockExtension(val)
|
||||||
|
|
||||||
|
mock_ctx = mock.MagicMock()
|
||||||
|
|
||||||
|
callbacks = mock_extension.get_callbacks()
|
||||||
|
pre_denoise_loop_cbs = callbacks.get(ExtensionCallbackType.PRE_DENOISE_LOOP, [])
|
||||||
|
assert len(pre_denoise_loop_cbs) == 1
|
||||||
|
|
||||||
|
# Call the mock callback.
|
||||||
|
pre_denoise_loop_cbs[0].function(mock_ctx)
|
||||||
|
|
||||||
|
# Confirm that the callback ran.
|
||||||
|
assert mock_ctx.step_index == val
|
||||||
|
|
||||||
|
|
||||||
|
def test_extension_base_empty_callback_type():
|
||||||
|
"""Test that an empty list is returned when no callbacks are registered for a given callback type."""
|
||||||
|
mock_extension = MockExtension(5)
|
||||||
|
|
||||||
|
# There should be no callbacks registered for POST_DENOISE_LOOP.
|
||||||
|
callbacks = mock_extension.get_callbacks()
|
||||||
|
|
||||||
|
post_denoise_loop_cbs = callbacks.get(ExtensionCallbackType.POST_DENOISE_LOOP, [])
|
||||||
|
assert len(post_denoise_loop_cbs) == 0
|
112
tests/backend/stable_diffusion/test_extension_manager.py
Normal file
112
tests/backend/stable_diffusion/test_extension_manager.py
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||||
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||||
|
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
|
||||||
|
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
||||||
|
|
||||||
|
|
||||||
|
class MockExtension(ExtensionBase):
|
||||||
|
"""A mock ExtensionBase subclass for testing purposes."""
|
||||||
|
|
||||||
|
def __init__(self, x: int):
|
||||||
|
super().__init__()
|
||||||
|
self._x = x
|
||||||
|
|
||||||
|
# Note that order is not specified. It should default to 0.
|
||||||
|
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
|
||||||
|
def set_step_index(self, ctx: DenoiseContext):
|
||||||
|
ctx.step_index = self._x
|
||||||
|
|
||||||
|
|
||||||
|
class MockExtensionLate(ExtensionBase):
|
||||||
|
"""A mock ExtensionBase subclass with a high order value on its PRE_DENOISE_LOOP callback."""
|
||||||
|
|
||||||
|
def __init__(self, x: int):
|
||||||
|
super().__init__()
|
||||||
|
self._x = x
|
||||||
|
|
||||||
|
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000)
|
||||||
|
def set_step_index(self, ctx: DenoiseContext):
|
||||||
|
ctx.step_index = self._x
|
||||||
|
|
||||||
|
|
||||||
|
def test_extension_manager_run_callback():
|
||||||
|
"""Test that run_callback runs all callbacks for the given callback type."""
|
||||||
|
|
||||||
|
em = ExtensionsManager()
|
||||||
|
mock_extension_1 = MockExtension(1)
|
||||||
|
em.add_extension(mock_extension_1)
|
||||||
|
|
||||||
|
mock_ctx = mock.MagicMock()
|
||||||
|
em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx)
|
||||||
|
|
||||||
|
assert mock_ctx.step_index == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_extension_manager_run_callback_no_callbacks():
|
||||||
|
"""Test that run_callback does not raise an error when there are no callbacks for the given callback type."""
|
||||||
|
em = ExtensionsManager()
|
||||||
|
mock_ctx = mock.MagicMock()
|
||||||
|
em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
["extension_1", "extension_2"],
|
||||||
|
# Regardless of initialization order, we expect MockExtensionLate to run last.
|
||||||
|
[(MockExtension(1), MockExtensionLate(2)), (MockExtensionLate(2), MockExtension(1))],
|
||||||
|
)
|
||||||
|
def test_extension_manager_order_callbacks(extension_1: ExtensionBase, extension_2: ExtensionBase):
|
||||||
|
"""Test that run_callback runs callbacks in the correct order."""
|
||||||
|
em = ExtensionsManager()
|
||||||
|
em.add_extension(extension_1)
|
||||||
|
em.add_extension(extension_2)
|
||||||
|
|
||||||
|
mock_ctx = mock.MagicMock()
|
||||||
|
em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx)
|
||||||
|
|
||||||
|
assert mock_ctx.step_index == 2
|
||||||
|
|
||||||
|
|
||||||
|
class MockExtensionStableSort(ExtensionBase):
|
||||||
|
"""A mock extension with three PRE_DENOISE_LOOP callbacks, each with a different order value."""
|
||||||
|
|
||||||
|
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=-1000)
|
||||||
|
def early(self, ctx: DenoiseContext):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
|
||||||
|
def middle(self, ctx: DenoiseContext):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000)
|
||||||
|
def late(self, ctx: DenoiseContext):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_extension_manager_stable_sort():
|
||||||
|
"""Test that when two callbacks have the same 'order' value, they are sorted based on the order they were added to
|
||||||
|
the ExtensionsManager."""
|
||||||
|
|
||||||
|
em = ExtensionsManager()
|
||||||
|
|
||||||
|
mock_extension_1 = MockExtensionStableSort()
|
||||||
|
mock_extension_2 = MockExtensionStableSort()
|
||||||
|
|
||||||
|
em.add_extension(mock_extension_1)
|
||||||
|
em.add_extension(mock_extension_2)
|
||||||
|
|
||||||
|
expected_order = [
|
||||||
|
mock_extension_1.early,
|
||||||
|
mock_extension_2.early,
|
||||||
|
mock_extension_1.middle,
|
||||||
|
mock_extension_2.middle,
|
||||||
|
mock_extension_1.late,
|
||||||
|
mock_extension_2.late,
|
||||||
|
]
|
||||||
|
|
||||||
|
# It's not ideal that we are accessing a private attribute here, but this was the most direct way to assert the
|
||||||
|
# desired behaviour.
|
||||||
|
assert [cb.function for cb in em._ordered_callbacks[ExtensionCallbackType.PRE_DENOISE_LOOP]] == expected_order
|
Loading…
Reference in New Issue
Block a user