Add run cancelling logic to extension manager

This commit is contained in:
Sergey Borisov 2024-07-17 04:39:15 +03:00
parent 3f79467f7b
commit 2ef3b49a79
2 changed files with 14 additions and 6 deletions

View File

@ -723,7 +723,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
@torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
ext_manager = ExtensionsManager()
ext_manager = ExtensionsManager(is_canceled=context.util.is_canceled)
device = TorchDevice.choose_torch_device()
dtype = TorchDevice.choose_torch_dtype()

View File

@ -3,14 +3,16 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from contextlib import ExitStack, contextmanager
from functools import partial
from typing import TYPE_CHECKING, Callable, Dict
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
import torch
from diffusers import UNet2DConditionModel
from invokeai.app.services.session_processor.session_processor_common import CanceledException
if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.extensions import ExtensionBase
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
class ExtCallbacksApi(ABC):
@ -71,10 +73,11 @@ class CallbackInjectionPoint:
class ExtensionsManager:
def __init__(self):
self.extensions = []
def __init__(self, is_canceled: Optional[Callable[[], bool]] = None):
self.extensions: List[ExtensionBase] = []
self._is_canceled = is_canceled
self._callbacks = {}
self._callbacks: Dict[str, CallbackInjectionPoint] = {}
self.callbacks: ExtCallbacksApi = ProxyCallsClass(self.call_callback)
def add_extension(self, ext: ExtensionBase):
@ -93,6 +96,11 @@ class ExtensionsManager:
raise Exception(f"Unsupported injection type: {inj_info.type}")
def call_callback(self, name: str, *args, **kwargs):
# TODO: add to patchers too?
# and if so, should it be only in beginning of function or in for loop
if self._is_canceled and self._is_canceled():
raise CanceledException
if name in self._callbacks:
self._callbacks[name](*args, **kwargs)