diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 6b314d10bf..994c99dc45 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -1,7 +1,7 @@ import threading from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Optional from PIL.Image import Image from torch import Tensor @@ -364,10 +364,14 @@ class ConfigInterface(InvocationContextInterface): class UtilInterface(InvocationContextInterface): def __init__( - self, services: InvocationServices, context_data: InvocationContextData, is_canceled: Callable[[], bool] + self, services: InvocationServices, context_data: InvocationContextData, cancel_event: threading.Event ) -> None: super().__init__(services, context_data) - self._is_canceled = is_canceled + self._cancel_event = cancel_event + + def is_canceled(self) -> bool: + """Checks if the current invocation has been canceled.""" + return self._cancel_event.is_set() def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None: """ @@ -385,7 +389,7 @@ class UtilInterface(InvocationContextInterface): intermediate_state=intermediate_state, base_model=base_model, events=self._services.events, - is_canceled=self._is_canceled, + is_canceled=self.is_canceled, ) @@ -406,7 +410,6 @@ class InvocationContext: boards: BoardsInterface, context_data: InvocationContextData, services: InvocationServices, - is_canceled: Callable[[], bool], ) -> None: self.images = images """Provides methods to save, get and update images and their metadata.""" @@ -424,8 +427,6 @@ class InvocationContext: """Provides utility methods.""" self.boards = boards """Provides methods to interact with boards.""" - self.is_canceled = is_canceled - """Checks if the current invocation has been canceled.""" self._data = context_data """Provides data about the current queue item and invocation. This is an internal API and may change without warning.""" self._services = services @@ -444,15 +445,12 @@ def build_invocation_context( :param invocation_context_data: The invocation context data. """ - def is_canceled() -> bool: - return cancel_event.is_set() - logger = LoggerInterface(services=services, context_data=context_data) images = ImagesInterface(services=services, context_data=context_data) tensors = TensorsInterface(services=services, context_data=context_data) models = ModelsInterface(services=services, context_data=context_data) config = ConfigInterface(services=services, context_data=context_data) - util = UtilInterface(services=services, context_data=context_data, is_canceled=is_canceled) + util = UtilInterface(services=services, context_data=context_data, cancel_event=cancel_event) conditioning = ConditioningInterface(services=services, context_data=context_data) boards = BoardsInterface(services=services, context_data=context_data) @@ -467,7 +465,6 @@ def build_invocation_context( conditioning=conditioning, services=services, boards=boards, - is_canceled=is_canceled, ) return ctx