refactor(nodes): move is_canceled to context.util

This commit is contained in:
psychedelicious 2024-02-18 11:54:16 +11:00
parent f31e4205aa
commit 39bdf5c4e9

View File

@ -1,7 +1,7 @@
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Optional
from typing import TYPE_CHECKING, Optional
from PIL.Image import Image
from torch import Tensor
@ -364,10 +364,14 @@ class ConfigInterface(InvocationContextInterface):
class UtilInterface(InvocationContextInterface):
def __init__(
self, services: InvocationServices, context_data: InvocationContextData, is_canceled: Callable[[], bool]
self, services: InvocationServices, context_data: InvocationContextData, cancel_event: threading.Event
) -> None:
super().__init__(services, context_data)
self._is_canceled = is_canceled
self._cancel_event = cancel_event
def is_canceled(self) -> bool:
"""Checks if the current invocation has been canceled."""
return self._cancel_event.is_set()
def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None:
"""
@ -385,7 +389,7 @@ class UtilInterface(InvocationContextInterface):
intermediate_state=intermediate_state,
base_model=base_model,
events=self._services.events,
is_canceled=self._is_canceled,
is_canceled=self.is_canceled,
)
@ -406,7 +410,6 @@ class InvocationContext:
boards: BoardsInterface,
context_data: InvocationContextData,
services: InvocationServices,
is_canceled: Callable[[], bool],
) -> None:
self.images = images
"""Provides methods to save, get and update images and their metadata."""
@ -424,8 +427,6 @@ class InvocationContext:
"""Provides utility methods."""
self.boards = boards
"""Provides methods to interact with boards."""
self.is_canceled = is_canceled
"""Checks if the current invocation has been canceled."""
self._data = context_data
"""Provides data about the current queue item and invocation. This is an internal API and may change without warning."""
self._services = services
@ -444,15 +445,12 @@ def build_invocation_context(
:param invocation_context_data: The invocation context data.
"""
def is_canceled() -> bool:
return cancel_event.is_set()
logger = LoggerInterface(services=services, context_data=context_data)
images = ImagesInterface(services=services, context_data=context_data)
tensors = TensorsInterface(services=services, context_data=context_data)
models = ModelsInterface(services=services, context_data=context_data)
config = ConfigInterface(services=services, context_data=context_data)
util = UtilInterface(services=services, context_data=context_data, is_canceled=is_canceled)
util = UtilInterface(services=services, context_data=context_data, cancel_event=cancel_event)
conditioning = ConditioningInterface(services=services, context_data=context_data)
boards = BoardsInterface(services=services, context_data=context_data)
@ -467,7 +465,6 @@ def build_invocation_context(
conditioning=conditioning,
services=services,
boards=boards,
is_canceled=is_canceled,
)
return ctx