refactor(nodes): move is_canceled to context.util

This commit is contained in:
psychedelicious 2024-02-18 11:54:16 +11:00
parent f31e4205aa
commit 39bdf5c4e9

View File

@ -1,7 +1,7 @@
import threading import threading
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Callable, Optional from typing import TYPE_CHECKING, Optional
from PIL.Image import Image from PIL.Image import Image
from torch import Tensor from torch import Tensor
@ -364,10 +364,14 @@ class ConfigInterface(InvocationContextInterface):
class UtilInterface(InvocationContextInterface): class UtilInterface(InvocationContextInterface):
def __init__( def __init__(
self, services: InvocationServices, context_data: InvocationContextData, is_canceled: Callable[[], bool] self, services: InvocationServices, context_data: InvocationContextData, cancel_event: threading.Event
) -> None: ) -> None:
super().__init__(services, context_data) super().__init__(services, context_data)
self._is_canceled = is_canceled self._cancel_event = cancel_event
def is_canceled(self) -> bool:
"""Checks if the current invocation has been canceled."""
return self._cancel_event.is_set()
def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None: def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None:
""" """
@ -385,7 +389,7 @@ class UtilInterface(InvocationContextInterface):
intermediate_state=intermediate_state, intermediate_state=intermediate_state,
base_model=base_model, base_model=base_model,
events=self._services.events, events=self._services.events,
is_canceled=self._is_canceled, is_canceled=self.is_canceled,
) )
@ -406,7 +410,6 @@ class InvocationContext:
boards: BoardsInterface, boards: BoardsInterface,
context_data: InvocationContextData, context_data: InvocationContextData,
services: InvocationServices, services: InvocationServices,
is_canceled: Callable[[], bool],
) -> None: ) -> None:
self.images = images self.images = images
"""Provides methods to save, get and update images and their metadata.""" """Provides methods to save, get and update images and their metadata."""
@ -424,8 +427,6 @@ class InvocationContext:
"""Provides utility methods.""" """Provides utility methods."""
self.boards = boards self.boards = boards
"""Provides methods to interact with boards.""" """Provides methods to interact with boards."""
self.is_canceled = is_canceled
"""Checks if the current invocation has been canceled."""
self._data = context_data self._data = context_data
"""Provides data about the current queue item and invocation. This is an internal API and may change without warning.""" """Provides data about the current queue item and invocation. This is an internal API and may change without warning."""
self._services = services self._services = services
@ -444,15 +445,12 @@ def build_invocation_context(
:param invocation_context_data: The invocation context data. :param invocation_context_data: The invocation context data.
""" """
def is_canceled() -> bool:
return cancel_event.is_set()
logger = LoggerInterface(services=services, context_data=context_data) logger = LoggerInterface(services=services, context_data=context_data)
images = ImagesInterface(services=services, context_data=context_data) images = ImagesInterface(services=services, context_data=context_data)
tensors = TensorsInterface(services=services, context_data=context_data) tensors = TensorsInterface(services=services, context_data=context_data)
models = ModelsInterface(services=services, context_data=context_data) models = ModelsInterface(services=services, context_data=context_data)
config = ConfigInterface(services=services, context_data=context_data) config = ConfigInterface(services=services, context_data=context_data)
util = UtilInterface(services=services, context_data=context_data, is_canceled=is_canceled) util = UtilInterface(services=services, context_data=context_data, cancel_event=cancel_event)
conditioning = ConditioningInterface(services=services, context_data=context_data) conditioning = ConditioningInterface(services=services, context_data=context_data)
boards = BoardsInterface(services=services, context_data=context_data) boards = BoardsInterface(services=services, context_data=context_data)
@ -467,7 +465,6 @@ def build_invocation_context(
conditioning=conditioning, conditioning=conditioning,
services=services, services=services,
boards=boards, boards=boards,
is_canceled=is_canceled,
) )
return ctx return ctx