mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add run cancelling logic to extension manager
This commit is contained in:
parent
3f79467f7b
commit
2ef3b49a79
@ -723,7 +723,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
|
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
|
||||||
def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
|
def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
ext_manager = ExtensionsManager()
|
ext_manager = ExtensionsManager(is_canceled=context.util.is_canceled)
|
||||||
|
|
||||||
device = TorchDevice.choose_torch_device()
|
device = TorchDevice.choose_torch_device()
|
||||||
dtype = TorchDevice.choose_torch_dtype()
|
dtype = TorchDevice.choose_torch_dtype()
|
||||||
|
@ -3,14 +3,16 @@ from __future__ import annotations
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from contextlib import ExitStack, contextmanager
|
from contextlib import ExitStack, contextmanager
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import TYPE_CHECKING, Callable, Dict
|
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers import UNet2DConditionModel
|
from diffusers import UNet2DConditionModel
|
||||||
|
|
||||||
|
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||||
from invokeai.backend.stable_diffusion.extensions import ExtensionBase
|
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
|
||||||
|
|
||||||
|
|
||||||
class ExtCallbacksApi(ABC):
|
class ExtCallbacksApi(ABC):
|
||||||
@ -71,10 +73,11 @@ class CallbackInjectionPoint:
|
|||||||
|
|
||||||
|
|
||||||
class ExtensionsManager:
|
class ExtensionsManager:
|
||||||
def __init__(self):
|
def __init__(self, is_canceled: Optional[Callable[[], bool]] = None):
|
||||||
self.extensions = []
|
self.extensions: List[ExtensionBase] = []
|
||||||
|
self._is_canceled = is_canceled
|
||||||
|
|
||||||
self._callbacks = {}
|
self._callbacks: Dict[str, CallbackInjectionPoint] = {}
|
||||||
self.callbacks: ExtCallbacksApi = ProxyCallsClass(self.call_callback)
|
self.callbacks: ExtCallbacksApi = ProxyCallsClass(self.call_callback)
|
||||||
|
|
||||||
def add_extension(self, ext: ExtensionBase):
|
def add_extension(self, ext: ExtensionBase):
|
||||||
@ -93,6 +96,11 @@ class ExtensionsManager:
|
|||||||
raise Exception(f"Unsupported injection type: {inj_info.type}")
|
raise Exception(f"Unsupported injection type: {inj_info.type}")
|
||||||
|
|
||||||
def call_callback(self, name: str, *args, **kwargs):
|
def call_callback(self, name: str, *args, **kwargs):
|
||||||
|
# TODO: add to patchers too?
|
||||||
|
# and if so, should it be only in beginning of function or in for loop
|
||||||
|
if self._is_canceled and self._is_canceled():
|
||||||
|
raise CanceledException
|
||||||
|
|
||||||
if name in self._callbacks:
|
if name in self._callbacks:
|
||||||
self._callbacks[name](*args, **kwargs)
|
self._callbacks[name](*args, **kwargs)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user