From 2ef3b49a7937d8b0efed6053a71000928b7986bc Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Wed, 17 Jul 2024 04:39:15 +0300 Subject: [PATCH] Add run cancelling logic to extension manager --- invokeai/app/invocations/denoise_latents.py | 2 +- .../stable_diffusion/extensions_manager.py | 18 +++++++++++++----- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index 17a79cca90..5b6d945b4e 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -723,7 +723,7 @@ class DenoiseLatentsInvocation(BaseInvocation): @torch.no_grad() @SilenceWarnings() # This quenches the NSFW nag from diffusers. def _new_invoke(self, context: InvocationContext) -> LatentsOutput: - ext_manager = ExtensionsManager() + ext_manager = ExtensionsManager(is_canceled=context.util.is_canceled) device = TorchDevice.choose_torch_device() dtype = TorchDevice.choose_torch_dtype() diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py index 213eb5d782..481d1dc358 100644 --- a/invokeai/backend/stable_diffusion/extensions_manager.py +++ b/invokeai/backend/stable_diffusion/extensions_manager.py @@ -3,14 +3,16 @@ from __future__ import annotations from abc import ABC, abstractmethod from contextlib import ExitStack, contextmanager from functools import partial -from typing import TYPE_CHECKING, Callable, Dict +from typing import TYPE_CHECKING, Callable, Dict, List, Optional import torch from diffusers import UNet2DConditionModel +from invokeai.app.services.session_processor.session_processor_common import CanceledException + if TYPE_CHECKING: from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext - from invokeai.backend.stable_diffusion.extensions import ExtensionBase + from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase class ExtCallbacksApi(ABC): @@ -71,10 +73,11 @@ class CallbackInjectionPoint: class ExtensionsManager: - def __init__(self): - self.extensions = [] + def __init__(self, is_canceled: Optional[Callable[[], bool]] = None): + self.extensions: List[ExtensionBase] = [] + self._is_canceled = is_canceled - self._callbacks = {} + self._callbacks: Dict[str, CallbackInjectionPoint] = {} self.callbacks: ExtCallbacksApi = ProxyCallsClass(self.call_callback) def add_extension(self, ext: ExtensionBase): @@ -93,6 +96,11 @@ class ExtensionsManager: raise Exception(f"Unsupported injection type: {inj_info.type}") def call_callback(self, name: str, *args, **kwargs): + # TODO: add to patchers too? + # and if so, should it be only in beginning of function or in for loop + if self._is_canceled and self._is_canceled(): + raise CanceledException + if name in self._callbacks: self._callbacks[name](*args, **kwargs)