mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Suggestion for how to clean up the ExtensionManager.
This commit is contained in:
parent
499e4d4fde
commit
51750e2060
12
invokeai/backend/stable_diffusion/extension_callback_type.py
Normal file
12
invokeai/backend/stable_diffusion/extension_callback_type.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class ExtensionCallbackType(Enum):
|
||||||
|
SETUP = "setup"
|
||||||
|
PRE_DENOISE_LOOP = "pre_denoise_loop"
|
||||||
|
POST_DENOISE_LOOP = "post_denoise_loop"
|
||||||
|
PRE_STEP = "pre_step"
|
||||||
|
POST_STEP = "post_step"
|
||||||
|
PRE_UNET = "pre_unet"
|
||||||
|
POST_UNET = "post_unet"
|
||||||
|
POST_APPLY_CFG = "post_apply_cfg"
|
34
invokeai/backend/stable_diffusion/extension_manager_2.py
Normal file
34
invokeai/backend/stable_diffusion/extension_manager_2.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||||
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||||
|
from invokeai.backend.stable_diffusion.extensions.extension_base_2 import CallbackFunctionWithMetadata, ExtensionBase
|
||||||
|
|
||||||
|
|
||||||
|
class ExtensionManager:
|
||||||
|
def __init__(self):
|
||||||
|
self._extensions: list[ExtensionBase] = []
|
||||||
|
self._ordered_callbacks: dict[ExtensionCallbackType, list[CallbackFunctionWithMetadata]] = {}
|
||||||
|
|
||||||
|
def add_extension(self, extension: ExtensionBase):
|
||||||
|
self._extensions.append(extension)
|
||||||
|
self._regenerate_ordered_callbacks()
|
||||||
|
|
||||||
|
def _regenerate_ordered_callbacks(self):
|
||||||
|
"""Regenerates self._ordered_callbacks. Intended to be called each time a new extension is added."""
|
||||||
|
self._ordered_callbacks = {}
|
||||||
|
|
||||||
|
# Fill the ordered callbacks dictionary.
|
||||||
|
for extension in self._extensions:
|
||||||
|
for callback_type, callbacks in extension.get_callbacks().items():
|
||||||
|
if callback_type not in self._ordered_callbacks:
|
||||||
|
self._ordered_callbacks[callback_type] = []
|
||||||
|
self._ordered_callbacks[callback_type].extend(callbacks)
|
||||||
|
|
||||||
|
# Sort each callback list.
|
||||||
|
for callback_type, callbacks in self._ordered_callbacks.items():
|
||||||
|
self._ordered_callbacks[callback_type] = sorted(callbacks, key=lambda x: x.metadata.order)
|
||||||
|
|
||||||
|
def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext):
|
||||||
|
cbs = self._ordered_callbacks.get(callback_type, [])
|
||||||
|
|
||||||
|
for cb in cbs:
|
||||||
|
cb.func(ctx)
|
@ -0,0 +1,60 @@
|
|||||||
|
from contextlib import contextmanager
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Callable, Dict, List, TypeVar
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from diffusers import UNet2DConditionModel
|
||||||
|
|
||||||
|
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||||
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CallbackMetadata:
|
||||||
|
callback_type: ExtensionCallbackType
|
||||||
|
order: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CallbackFunctionWithMetadata:
|
||||||
|
metadata: CallbackMetadata
|
||||||
|
func: Callable[[DenoiseContext], None]
|
||||||
|
|
||||||
|
|
||||||
|
# A TypeVar that represents any subclass of ExtensionBase.
|
||||||
|
TExtensionBaseSubclass = TypeVar("TExtensionBaseSubclass", bound="ExtensionBase")
|
||||||
|
|
||||||
|
|
||||||
|
def callback(callback_type: ExtensionCallbackType, order: int = 0):
|
||||||
|
"""A decorator that marks an extension method as a callback."""
|
||||||
|
|
||||||
|
def _decorator(func: Callable[[TExtensionBaseSubclass, DenoiseContext], None]):
|
||||||
|
func._metadata = CallbackMetadata(callback_type, order) # type: ignore
|
||||||
|
return func
|
||||||
|
|
||||||
|
return _decorator
|
||||||
|
|
||||||
|
|
||||||
|
class ExtensionBase:
|
||||||
|
def __init__(self):
|
||||||
|
self._callbacks: dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}
|
||||||
|
|
||||||
|
# Register all of the callback methods for this instance.
|
||||||
|
for func_name in dir(self):
|
||||||
|
func = getattr(self, func_name)
|
||||||
|
metadata = getattr(func, "_metadata", None)
|
||||||
|
if metadata is not None and isinstance(metadata, CallbackMetadata):
|
||||||
|
if metadata.callback_type not in self._callbacks:
|
||||||
|
self._callbacks[metadata.callback_type] = []
|
||||||
|
self._callbacks[metadata.callback_type].append(CallbackFunctionWithMetadata(metadata, func))
|
||||||
|
|
||||||
|
def get_callbacks(self):
|
||||||
|
return self._callbacks
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def patch_attention_processor(self, attention_processor_cls: object):
|
||||||
|
yield None
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
|
||||||
|
yield None
|
@ -0,0 +1,44 @@
|
|||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||||
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||||
|
from invokeai.backend.stable_diffusion.extensions.extension_base_2 import ExtensionBase, callback
|
||||||
|
|
||||||
|
|
||||||
|
class MockExtension(ExtensionBase):
|
||||||
|
"""A mock ExtensionBase subclass for testing purposes."""
|
||||||
|
|
||||||
|
def __init__(self, x: int):
|
||||||
|
super().__init__()
|
||||||
|
self._x = x
|
||||||
|
|
||||||
|
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
|
||||||
|
def set_step_index(self, ctx: DenoiseContext):
|
||||||
|
ctx.step_index = self._x
|
||||||
|
|
||||||
|
|
||||||
|
def test_extension_base_callback_registration():
|
||||||
|
val = 5
|
||||||
|
mock_extension = MockExtension(val)
|
||||||
|
|
||||||
|
mock_ctx = mock.MagicMock()
|
||||||
|
|
||||||
|
callbacks = mock_extension.get_callbacks()
|
||||||
|
pre_denoise_loop_cbs = callbacks.get(ExtensionCallbackType.PRE_DENOISE_LOOP, [])
|
||||||
|
assert len(pre_denoise_loop_cbs) == 1
|
||||||
|
|
||||||
|
# Call the mock callback.
|
||||||
|
pre_denoise_loop_cbs[0].func(mock_ctx)
|
||||||
|
|
||||||
|
# Confirm that the callback ran.
|
||||||
|
assert mock_ctx.step_index == val
|
||||||
|
|
||||||
|
|
||||||
|
def test_extension_base_empty_callback_type():
|
||||||
|
mock_extension = MockExtension(5)
|
||||||
|
|
||||||
|
# There should be no callbacks registered for POST_DENOISE_LOOP.
|
||||||
|
callbacks = mock_extension.get_callbacks()
|
||||||
|
|
||||||
|
post_denoise_loop_cbs = callbacks.get(ExtensionCallbackType.POST_DENOISE_LOOP, [])
|
||||||
|
assert len(post_denoise_loop_cbs) == 0
|
73
tests/backend/stable_diffusion/test_extension_manager_2.py
Normal file
73
tests/backend/stable_diffusion/test_extension_manager_2.py
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||||
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||||
|
from invokeai.backend.stable_diffusion.extension_manager_2 import ExtensionManager
|
||||||
|
from invokeai.backend.stable_diffusion.extensions.extension_base_2 import (
|
||||||
|
ExtensionBase,
|
||||||
|
callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MockExtension(ExtensionBase):
|
||||||
|
"""A mock ExtensionBase subclass for testing purposes."""
|
||||||
|
|
||||||
|
def __init__(self, x: int):
|
||||||
|
super().__init__()
|
||||||
|
self._x = x
|
||||||
|
|
||||||
|
# Note that order is not specified. It should default to 0.
|
||||||
|
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
|
||||||
|
def set_step_index(self, ctx: DenoiseContext):
|
||||||
|
ctx.step_index = self._x
|
||||||
|
|
||||||
|
|
||||||
|
class MockExtensionLate(ExtensionBase):
|
||||||
|
"""A mock ExtensionBase subclass with a high order value on its PRE_DENOISE_LOOP callback."""
|
||||||
|
|
||||||
|
def __init__(self, x: int):
|
||||||
|
super().__init__()
|
||||||
|
self._x = x
|
||||||
|
|
||||||
|
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000)
|
||||||
|
def set_step_index(self, ctx: DenoiseContext):
|
||||||
|
ctx.step_index = self._x
|
||||||
|
|
||||||
|
|
||||||
|
def test_extension_manager_run_callback():
|
||||||
|
"""Test that run_callback runs all callbacks for the given callback type."""
|
||||||
|
|
||||||
|
em = ExtensionManager()
|
||||||
|
mock_extension_1 = MockExtension(1)
|
||||||
|
em.add_extension(mock_extension_1)
|
||||||
|
|
||||||
|
mock_ctx = mock.MagicMock()
|
||||||
|
em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx)
|
||||||
|
|
||||||
|
assert mock_ctx.step_index == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_extension_manager_run_callback_no_callbacks():
|
||||||
|
"""Test that run_callback does not raise an error when there are no callbacks for the given callback type."""
|
||||||
|
em = ExtensionManager()
|
||||||
|
mock_ctx = mock.MagicMock()
|
||||||
|
em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
["extension_1", "extension_2"],
|
||||||
|
# Regardless of initialization order, we expect MockExtensionLate to run last.
|
||||||
|
[(MockExtension(1), MockExtensionLate(2)), (MockExtensionLate(2), MockExtension(1))],
|
||||||
|
)
|
||||||
|
def test_extension_manager_order_callbacks(extension_1: ExtensionBase, extension_2: ExtensionBase):
|
||||||
|
"""Test that run_callback runs callbacks in the correct order."""
|
||||||
|
em = ExtensionManager()
|
||||||
|
em.add_extension(extension_1)
|
||||||
|
em.add_extension(extension_2)
|
||||||
|
|
||||||
|
mock_ctx = mock.MagicMock()
|
||||||
|
em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx)
|
||||||
|
|
||||||
|
assert mock_ctx.step_index == 2
|
Loading…
Reference in New Issue
Block a user