Suggestion for how to clean up the ExtensionManager.

This commit is contained in:
Ryan Dick 2024-07-15 15:39:16 -04:00
parent 499e4d4fde
commit 51750e2060
5 changed files with 223 additions and 0 deletions

View File

@ -0,0 +1,12 @@
from enum import Enum
class ExtensionCallbackType(Enum):
SETUP = "setup"
PRE_DENOISE_LOOP = "pre_denoise_loop"
POST_DENOISE_LOOP = "post_denoise_loop"
PRE_STEP = "pre_step"
POST_STEP = "post_step"
PRE_UNET = "pre_unet"
POST_UNET = "post_unet"
POST_APPLY_CFG = "post_apply_cfg"

View File

@ -0,0 +1,34 @@
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.extension_base_2 import CallbackFunctionWithMetadata, ExtensionBase
class ExtensionManager:
def __init__(self):
self._extensions: list[ExtensionBase] = []
self._ordered_callbacks: dict[ExtensionCallbackType, list[CallbackFunctionWithMetadata]] = {}
def add_extension(self, extension: ExtensionBase):
self._extensions.append(extension)
self._regenerate_ordered_callbacks()
def _regenerate_ordered_callbacks(self):
"""Regenerates self._ordered_callbacks. Intended to be called each time a new extension is added."""
self._ordered_callbacks = {}
# Fill the ordered callbacks dictionary.
for extension in self._extensions:
for callback_type, callbacks in extension.get_callbacks().items():
if callback_type not in self._ordered_callbacks:
self._ordered_callbacks[callback_type] = []
self._ordered_callbacks[callback_type].extend(callbacks)
# Sort each callback list.
for callback_type, callbacks in self._ordered_callbacks.items():
self._ordered_callbacks[callback_type] = sorted(callbacks, key=lambda x: x.metadata.order)
def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext):
cbs = self._ordered_callbacks.get(callback_type, [])
for cb in cbs:
cb.func(ctx)

View File

@ -0,0 +1,60 @@
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Callable, Dict, List, TypeVar
import torch
from diffusers import UNet2DConditionModel
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
@dataclass
class CallbackMetadata:
callback_type: ExtensionCallbackType
order: int
@dataclass
class CallbackFunctionWithMetadata:
metadata: CallbackMetadata
func: Callable[[DenoiseContext], None]
# A TypeVar that represents any subclass of ExtensionBase.
TExtensionBaseSubclass = TypeVar("TExtensionBaseSubclass", bound="ExtensionBase")
def callback(callback_type: ExtensionCallbackType, order: int = 0):
"""A decorator that marks an extension method as a callback."""
def _decorator(func: Callable[[TExtensionBaseSubclass, DenoiseContext], None]):
func._metadata = CallbackMetadata(callback_type, order) # type: ignore
return func
return _decorator
class ExtensionBase:
def __init__(self):
self._callbacks: dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}
# Register all of the callback methods for this instance.
for func_name in dir(self):
func = getattr(self, func_name)
metadata = getattr(func, "_metadata", None)
if metadata is not None and isinstance(metadata, CallbackMetadata):
if metadata.callback_type not in self._callbacks:
self._callbacks[metadata.callback_type] = []
self._callbacks[metadata.callback_type].append(CallbackFunctionWithMetadata(metadata, func))
def get_callbacks(self):
return self._callbacks
@contextmanager
def patch_attention_processor(self, attention_processor_cls: object):
yield None
@contextmanager
def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
yield None

View File

@ -0,0 +1,44 @@
from unittest import mock
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.extension_base_2 import ExtensionBase, callback
class MockExtension(ExtensionBase):
"""A mock ExtensionBase subclass for testing purposes."""
def __init__(self, x: int):
super().__init__()
self._x = x
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
def set_step_index(self, ctx: DenoiseContext):
ctx.step_index = self._x
def test_extension_base_callback_registration():
val = 5
mock_extension = MockExtension(val)
mock_ctx = mock.MagicMock()
callbacks = mock_extension.get_callbacks()
pre_denoise_loop_cbs = callbacks.get(ExtensionCallbackType.PRE_DENOISE_LOOP, [])
assert len(pre_denoise_loop_cbs) == 1
# Call the mock callback.
pre_denoise_loop_cbs[0].func(mock_ctx)
# Confirm that the callback ran.
assert mock_ctx.step_index == val
def test_extension_base_empty_callback_type():
mock_extension = MockExtension(5)
# There should be no callbacks registered for POST_DENOISE_LOOP.
callbacks = mock_extension.get_callbacks()
post_denoise_loop_cbs = callbacks.get(ExtensionCallbackType.POST_DENOISE_LOOP, [])
assert len(post_denoise_loop_cbs) == 0

View File

@ -0,0 +1,73 @@
from unittest import mock
import pytest
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extension_manager_2 import ExtensionManager
from invokeai.backend.stable_diffusion.extensions.extension_base_2 import (
ExtensionBase,
callback,
)
class MockExtension(ExtensionBase):
"""A mock ExtensionBase subclass for testing purposes."""
def __init__(self, x: int):
super().__init__()
self._x = x
# Note that order is not specified. It should default to 0.
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
def set_step_index(self, ctx: DenoiseContext):
ctx.step_index = self._x
class MockExtensionLate(ExtensionBase):
"""A mock ExtensionBase subclass with a high order value on its PRE_DENOISE_LOOP callback."""
def __init__(self, x: int):
super().__init__()
self._x = x
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000)
def set_step_index(self, ctx: DenoiseContext):
ctx.step_index = self._x
def test_extension_manager_run_callback():
"""Test that run_callback runs all callbacks for the given callback type."""
em = ExtensionManager()
mock_extension_1 = MockExtension(1)
em.add_extension(mock_extension_1)
mock_ctx = mock.MagicMock()
em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx)
assert mock_ctx.step_index == 1
def test_extension_manager_run_callback_no_callbacks():
"""Test that run_callback does not raise an error when there are no callbacks for the given callback type."""
em = ExtensionManager()
mock_ctx = mock.MagicMock()
em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx)
@pytest.mark.parametrize(
["extension_1", "extension_2"],
# Regardless of initialization order, we expect MockExtensionLate to run last.
[(MockExtension(1), MockExtensionLate(2)), (MockExtensionLate(2), MockExtension(1))],
)
def test_extension_manager_order_callbacks(extension_1: ExtensionBase, extension_2: ExtensionBase):
"""Test that run_callback runs callbacks in the correct order."""
em = ExtensionManager()
em.add_extension(extension_1)
em.add_extension(extension_2)
mock_ctx = mock.MagicMock()
em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx)
assert mock_ctx.step_index == 2