InvokeAI/tests/backend/stable_diffusion/test_extension_manager.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

113 lines
3.8 KiB
Python
Raw Normal View History

from unittest import mock
import pytest
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
class MockExtension(ExtensionBase):
    """Minimal ExtensionBase subclass used as a test double.

    Its single PRE_DENOISE_LOOP callback writes the value supplied at construction time to ``ctx.step_index``.
    """

    def __init__(self, x: int):
        super().__init__()
        # Value that set_step_index() will store on the context.
        self._step_index_value = x

    # The callback order is deliberately left unspecified here; it is expected to default to 0.
    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
    def set_step_index(self, ctx: DenoiseContext):
        """Overwrite ``ctx.step_index`` with the value passed to the constructor."""
        ctx.step_index = self._step_index_value
class MockExtensionLate(ExtensionBase):
    """Test double whose PRE_DENOISE_LOOP callback is registered with a high order value (1000).

    The callback writes the value supplied at construction time to ``ctx.step_index``.
    """

    def __init__(self, x: int):
        super().__init__()
        # Value that set_step_index() will store on the context.
        self._step_index_value = x

    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000)
    def set_step_index(self, ctx: DenoiseContext):
        """Overwrite ``ctx.step_index`` with the value passed to the constructor."""
        ctx.step_index = self._step_index_value
def test_extension_manager_run_callback():
    """Check that run_callback invokes the callbacks registered for the requested callback type."""
    # Arrange: a manager holding a single extension that writes 1 to ctx.step_index.
    manager = ExtensionsManager()
    extension = MockExtension(1)
    manager.add_extension(extension)
    ctx = mock.MagicMock()

    # Act
    manager.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, ctx)

    # Assert: the extension's PRE_DENOISE_LOOP callback ran against our context.
    assert ctx.step_index == 1
def test_extension_manager_run_callback_no_callbacks():
    """Check that run_callback completes without raising when no callbacks exist for the callback type."""
    ctx = mock.MagicMock()
    manager = ExtensionsManager()

    # No extensions were added. The test passes as long as this call does not raise.
    manager.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, ctx)
@pytest.mark.parametrize(
    ["extension_1", "extension_2"],
    # Both registration orders are exercised; MockExtensionLate is expected to run last in either case.
    [
        (MockExtension(1), MockExtensionLate(2)),
        (MockExtensionLate(2), MockExtension(1)),
    ],
)
def test_extension_manager_order_callbacks(extension_1: ExtensionBase, extension_2: ExtensionBase):
    """Check that run_callback honours the callbacks' order values rather than the registration order."""
    manager = ExtensionsManager()
    for extension in (extension_1, extension_2):
        manager.add_extension(extension)

    ctx = mock.MagicMock()
    manager.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, ctx)

    # The last writer wins: MockExtensionLate(2) must have run after MockExtension(1).
    assert ctx.step_index == 2
class MockExtensionStableSort(ExtensionBase):
    """Test double exposing three no-op PRE_DENOISE_LOOP callbacks, each registered with a different order."""

    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=-1000)
    def early(self, ctx: DenoiseContext):
        """No-op callback registered with order=-1000."""

    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
    def middle(self, ctx: DenoiseContext):
        """No-op callback registered without an explicit order."""

    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000)
    def late(self, ctx: DenoiseContext):
        """No-op callback registered with order=1000."""
def test_extension_manager_stable_sort():
    """Check that callbacks sharing an 'order' value keep the order in which their extensions were added to the
    ExtensionsManager."""
    manager = ExtensionsManager()
    first = MockExtensionStableSort()
    second = MockExtensionStableSort()
    manager.add_extension(first)
    manager.add_extension(second)

    # Callbacks are grouped by order value (early, middle, late); within each group the extension that was
    # added first is expected to come first.
    expected = [
        getattr(extension, callback_name)
        for callback_name in ("early", "middle", "late")
        for extension in (first, second)
    ]

    # Reading a private attribute is not ideal, but it is the most direct way to observe the resulting ordering.
    registered = manager._ordered_callbacks[ExtensionCallbackType.PRE_DENOISE_LOOP]
    actual = [entry.function for entry in registered]
    assert actual == expected