mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
47 lines
1.6 KiB
Python
47 lines
1.6 KiB
Python
from unittest import mock
|
|
|
|
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
|
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
|
|
|
|
|
|
class MockExtension(ExtensionBase):
    """Minimal ExtensionBase subclass used as a test fixture.

    It stores a single integer and exposes one method decorated with
    ``@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)``.
    """

    def __init__(self, x: int) -> None:
        """Store ``x`` so the callback can later write it to the context."""
        # Run the base initializer before setting our own state.
        super().__init__()
        self._x = x

    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
    def set_step_index(self, ctx: DenoiseContext) -> None:
        """Assign the stored value to ``ctx.step_index``."""
        ctx.step_index = self._x
|
|
|
|
|
|
def test_extension_base_callback_registration():
    """Check that a decorated method shows up as a registered callback and can be invoked."""
    expected_step_index = 5
    extension = MockExtension(expected_step_index)

    # A MagicMock stands in for the context; the callback only sets an attribute on it.
    ctx = mock.MagicMock()

    # Exactly one callback should be present under PRE_DENOISE_LOOP.
    registered = extension.get_callbacks().get(ExtensionCallbackType.PRE_DENOISE_LOOP, [])
    assert len(registered) == 1

    # Invoke the registered callback with the mock context.
    registered[0].function(ctx)

    # The callback ran if the stored value was written to the context.
    assert ctx.step_index == expected_step_index
|
|
|
|
|
|
def test_extension_base_empty_callback_type():
    """Check that looking up a callback type with no registered callbacks yields nothing."""
    extension = MockExtension(5)
    registered = extension.get_callbacks()

    # MockExtension defines no POST_DENOISE_LOOP callback, so the lookup
    # (falling back to an empty list) must produce zero entries.
    post_denoise_cbs = registered.get(ExtensionCallbackType.POST_DENOISE_LOOP, [])
    assert len(post_denoise_cbs) == 0
|