mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Suggestion for how to clean up the ExtensionManager.
This commit is contained in:
parent
499e4d4fde
commit
51750e2060
12
invokeai/backend/stable_diffusion/extension_callback_type.py
Normal file
12
invokeai/backend/stable_diffusion/extension_callback_type.py
Normal file
@ -0,0 +1,12 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ExtensionCallbackType(Enum):
|
||||
SETUP = "setup"
|
||||
PRE_DENOISE_LOOP = "pre_denoise_loop"
|
||||
POST_DENOISE_LOOP = "post_denoise_loop"
|
||||
PRE_STEP = "pre_step"
|
||||
POST_STEP = "post_step"
|
||||
PRE_UNET = "pre_unet"
|
||||
POST_UNET = "post_unet"
|
||||
POST_APPLY_CFG = "post_apply_cfg"
|
34
invokeai/backend/stable_diffusion/extension_manager_2.py
Normal file
34
invokeai/backend/stable_diffusion/extension_manager_2.py
Normal file
@ -0,0 +1,34 @@
|
||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||
from invokeai.backend.stable_diffusion.extensions.extension_base_2 import CallbackFunctionWithMetadata, ExtensionBase
|
||||
|
||||
|
||||
class ExtensionManager:
|
||||
def __init__(self):
|
||||
self._extensions: list[ExtensionBase] = []
|
||||
self._ordered_callbacks: dict[ExtensionCallbackType, list[CallbackFunctionWithMetadata]] = {}
|
||||
|
||||
def add_extension(self, extension: ExtensionBase):
|
||||
self._extensions.append(extension)
|
||||
self._regenerate_ordered_callbacks()
|
||||
|
||||
def _regenerate_ordered_callbacks(self):
|
||||
"""Regenerates self._ordered_callbacks. Intended to be called each time a new extension is added."""
|
||||
self._ordered_callbacks = {}
|
||||
|
||||
# Fill the ordered callbacks dictionary.
|
||||
for extension in self._extensions:
|
||||
for callback_type, callbacks in extension.get_callbacks().items():
|
||||
if callback_type not in self._ordered_callbacks:
|
||||
self._ordered_callbacks[callback_type] = []
|
||||
self._ordered_callbacks[callback_type].extend(callbacks)
|
||||
|
||||
# Sort each callback list.
|
||||
for callback_type, callbacks in self._ordered_callbacks.items():
|
||||
self._ordered_callbacks[callback_type] = sorted(callbacks, key=lambda x: x.metadata.order)
|
||||
|
||||
def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext):
|
||||
cbs = self._ordered_callbacks.get(callback_type, [])
|
||||
|
||||
for cb in cbs:
|
||||
cb.func(ctx)
|
@ -0,0 +1,60 @@
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict, List, TypeVar
|
||||
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
|
||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||
|
||||
|
||||
@dataclass
|
||||
class CallbackMetadata:
|
||||
callback_type: ExtensionCallbackType
|
||||
order: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class CallbackFunctionWithMetadata:
|
||||
metadata: CallbackMetadata
|
||||
func: Callable[[DenoiseContext], None]
|
||||
|
||||
|
||||
# A TypeVar that represents any subclass of ExtensionBase.
|
||||
TExtensionBaseSubclass = TypeVar("TExtensionBaseSubclass", bound="ExtensionBase")
|
||||
|
||||
|
||||
def callback(callback_type: ExtensionCallbackType, order: int = 0):
|
||||
"""A decorator that marks an extension method as a callback."""
|
||||
|
||||
def _decorator(func: Callable[[TExtensionBaseSubclass, DenoiseContext], None]):
|
||||
func._metadata = CallbackMetadata(callback_type, order) # type: ignore
|
||||
return func
|
||||
|
||||
return _decorator
|
||||
|
||||
|
||||
class ExtensionBase:
|
||||
def __init__(self):
|
||||
self._callbacks: dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}
|
||||
|
||||
# Register all of the callback methods for this instance.
|
||||
for func_name in dir(self):
|
||||
func = getattr(self, func_name)
|
||||
metadata = getattr(func, "_metadata", None)
|
||||
if metadata is not None and isinstance(metadata, CallbackMetadata):
|
||||
if metadata.callback_type not in self._callbacks:
|
||||
self._callbacks[metadata.callback_type] = []
|
||||
self._callbacks[metadata.callback_type].append(CallbackFunctionWithMetadata(metadata, func))
|
||||
|
||||
def get_callbacks(self):
|
||||
return self._callbacks
|
||||
|
||||
@contextmanager
|
||||
def patch_attention_processor(self, attention_processor_cls: object):
|
||||
yield None
|
||||
|
||||
@contextmanager
|
||||
def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
|
||||
yield None
|
@ -0,0 +1,44 @@
|
||||
from unittest import mock
|
||||
|
||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||
from invokeai.backend.stable_diffusion.extensions.extension_base_2 import ExtensionBase, callback
|
||||
|
||||
|
||||
class MockExtension(ExtensionBase):
|
||||
"""A mock ExtensionBase subclass for testing purposes."""
|
||||
|
||||
def __init__(self, x: int):
|
||||
super().__init__()
|
||||
self._x = x
|
||||
|
||||
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
|
||||
def set_step_index(self, ctx: DenoiseContext):
|
||||
ctx.step_index = self._x
|
||||
|
||||
|
||||
def test_extension_base_callback_registration():
|
||||
val = 5
|
||||
mock_extension = MockExtension(val)
|
||||
|
||||
mock_ctx = mock.MagicMock()
|
||||
|
||||
callbacks = mock_extension.get_callbacks()
|
||||
pre_denoise_loop_cbs = callbacks.get(ExtensionCallbackType.PRE_DENOISE_LOOP, [])
|
||||
assert len(pre_denoise_loop_cbs) == 1
|
||||
|
||||
# Call the mock callback.
|
||||
pre_denoise_loop_cbs[0].func(mock_ctx)
|
||||
|
||||
# Confirm that the callback ran.
|
||||
assert mock_ctx.step_index == val
|
||||
|
||||
|
||||
def test_extension_base_empty_callback_type():
|
||||
mock_extension = MockExtension(5)
|
||||
|
||||
# There should be no callbacks registered for POST_DENOISE_LOOP.
|
||||
callbacks = mock_extension.get_callbacks()
|
||||
|
||||
post_denoise_loop_cbs = callbacks.get(ExtensionCallbackType.POST_DENOISE_LOOP, [])
|
||||
assert len(post_denoise_loop_cbs) == 0
|
73
tests/backend/stable_diffusion/test_extension_manager_2.py
Normal file
73
tests/backend/stable_diffusion/test_extension_manager_2.py
Normal file
@ -0,0 +1,73 @@
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||
from invokeai.backend.stable_diffusion.extension_manager_2 import ExtensionManager
|
||||
from invokeai.backend.stable_diffusion.extensions.extension_base_2 import (
|
||||
ExtensionBase,
|
||||
callback,
|
||||
)
|
||||
|
||||
|
||||
class MockExtension(ExtensionBase):
|
||||
"""A mock ExtensionBase subclass for testing purposes."""
|
||||
|
||||
def __init__(self, x: int):
|
||||
super().__init__()
|
||||
self._x = x
|
||||
|
||||
# Note that order is not specified. It should default to 0.
|
||||
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
|
||||
def set_step_index(self, ctx: DenoiseContext):
|
||||
ctx.step_index = self._x
|
||||
|
||||
|
||||
class MockExtensionLate(ExtensionBase):
|
||||
"""A mock ExtensionBase subclass with a high order value on its PRE_DENOISE_LOOP callback."""
|
||||
|
||||
def __init__(self, x: int):
|
||||
super().__init__()
|
||||
self._x = x
|
||||
|
||||
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000)
|
||||
def set_step_index(self, ctx: DenoiseContext):
|
||||
ctx.step_index = self._x
|
||||
|
||||
|
||||
def test_extension_manager_run_callback():
|
||||
"""Test that run_callback runs all callbacks for the given callback type."""
|
||||
|
||||
em = ExtensionManager()
|
||||
mock_extension_1 = MockExtension(1)
|
||||
em.add_extension(mock_extension_1)
|
||||
|
||||
mock_ctx = mock.MagicMock()
|
||||
em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx)
|
||||
|
||||
assert mock_ctx.step_index == 1
|
||||
|
||||
|
||||
def test_extension_manager_run_callback_no_callbacks():
|
||||
"""Test that run_callback does not raise an error when there are no callbacks for the given callback type."""
|
||||
em = ExtensionManager()
|
||||
mock_ctx = mock.MagicMock()
|
||||
em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["extension_1", "extension_2"],
|
||||
# Regardless of initialization order, we expect MockExtensionLate to run last.
|
||||
[(MockExtension(1), MockExtensionLate(2)), (MockExtensionLate(2), MockExtension(1))],
|
||||
)
|
||||
def test_extension_manager_order_callbacks(extension_1: ExtensionBase, extension_2: ExtensionBase):
|
||||
"""Test that run_callback runs callbacks in the correct order."""
|
||||
em = ExtensionManager()
|
||||
em.add_extension(extension_1)
|
||||
em.add_extension(extension_2)
|
||||
|
||||
mock_ctx = mock.MagicMock()
|
||||
em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx)
|
||||
|
||||
assert mock_ctx.step_index == 2
|
Loading…
Reference in New Issue
Block a user