Add run cancelling logic to extension manager

This commit is contained in:
Sergey Borisov 2024-07-17 04:39:15 +03:00
parent 3f79467f7b
commit 2ef3b49a79
2 changed files with 14 additions and 6 deletions

View File

@ -723,7 +723,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
@torch.no_grad() @torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers. @SilenceWarnings() # This quenches the NSFW nag from diffusers.
def _new_invoke(self, context: InvocationContext) -> LatentsOutput: def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
ext_manager = ExtensionsManager() ext_manager = ExtensionsManager(is_canceled=context.util.is_canceled)
device = TorchDevice.choose_torch_device() device = TorchDevice.choose_torch_device()
dtype = TorchDevice.choose_torch_dtype() dtype = TorchDevice.choose_torch_dtype()

View File

@ -3,14 +3,16 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import ExitStack, contextmanager from contextlib import ExitStack, contextmanager
from functools import partial from functools import partial
from typing import TYPE_CHECKING, Callable, Dict from typing import TYPE_CHECKING, Callable, Dict, List, Optional
import torch import torch
from diffusers import UNet2DConditionModel from diffusers import UNet2DConditionModel
from invokeai.app.services.session_processor.session_processor_common import CanceledException
if TYPE_CHECKING: if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.extensions import ExtensionBase from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
class ExtCallbacksApi(ABC): class ExtCallbacksApi(ABC):
@ -71,10 +73,11 @@ class CallbackInjectionPoint:
class ExtensionsManager: class ExtensionsManager:
def __init__(self): def __init__(self, is_canceled: Optional[Callable[[], bool]] = None):
self.extensions = [] self.extensions: List[ExtensionBase] = []
self._is_canceled = is_canceled
self._callbacks = {} self._callbacks: Dict[str, CallbackInjectionPoint] = {}
self.callbacks: ExtCallbacksApi = ProxyCallsClass(self.call_callback) self.callbacks: ExtCallbacksApi = ProxyCallsClass(self.call_callback)
def add_extension(self, ext: ExtensionBase): def add_extension(self, ext: ExtensionBase):
@ -93,6 +96,11 @@ class ExtensionsManager:
raise Exception(f"Unsupported injection type: {inj_info.type}") raise Exception(f"Unsupported injection type: {inj_info.type}")
def call_callback(self, name: str, *args, **kwargs): def call_callback(self, name: str, *args, **kwargs):
# TODO: add to patchers too?
# and if so, should it be only in beginning of function or in for loop
if self._is_canceled and self._is_canceled():
raise CanceledException
if name in self._callbacks: if name in self._callbacks:
self._callbacks[name](*args, **kwargs) self._callbacks[name](*args, **kwargs)