mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor(nodes): move is_canceled to context.util
This commit is contained in:
parent
18adcc1dd2
commit
fdac0c3c9b
@ -1,7 +1,7 @@
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Callable, Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from PIL.Image import Image
|
||||
from torch import Tensor
|
||||
@ -364,10 +364,14 @@ class ConfigInterface(InvocationContextInterface):
|
||||
|
||||
class UtilInterface(InvocationContextInterface):
|
||||
def __init__(
|
||||
self, services: InvocationServices, context_data: InvocationContextData, is_canceled: Callable[[], bool]
|
||||
self, services: InvocationServices, context_data: InvocationContextData, cancel_event: threading.Event
|
||||
) -> None:
|
||||
super().__init__(services, context_data)
|
||||
self._is_canceled = is_canceled
|
||||
self._cancel_event = cancel_event
|
||||
|
||||
def is_canceled(self) -> bool:
|
||||
"""Checks if the current invocation has been canceled."""
|
||||
return self._cancel_event.is_set()
|
||||
|
||||
def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None:
|
||||
"""
|
||||
@ -385,7 +389,7 @@ class UtilInterface(InvocationContextInterface):
|
||||
intermediate_state=intermediate_state,
|
||||
base_model=base_model,
|
||||
events=self._services.events,
|
||||
is_canceled=self._is_canceled,
|
||||
is_canceled=self.is_canceled,
|
||||
)
|
||||
|
||||
|
||||
@ -406,7 +410,6 @@ class InvocationContext:
|
||||
boards: BoardsInterface,
|
||||
context_data: InvocationContextData,
|
||||
services: InvocationServices,
|
||||
is_canceled: Callable[[], bool],
|
||||
) -> None:
|
||||
self.images = images
|
||||
"""Provides methods to save, get and update images and their metadata."""
|
||||
@ -424,8 +427,6 @@ class InvocationContext:
|
||||
"""Provides utility methods."""
|
||||
self.boards = boards
|
||||
"""Provides methods to interact with boards."""
|
||||
self.is_canceled = is_canceled
|
||||
"""Checks if the current invocation has been canceled."""
|
||||
self._data = context_data
|
||||
"""Provides data about the current queue item and invocation. This is an internal API and may change without warning."""
|
||||
self._services = services
|
||||
@ -444,15 +445,12 @@ def build_invocation_context(
|
||||
:param invocation_context_data: The invocation context data.
|
||||
"""
|
||||
|
||||
def is_canceled() -> bool:
|
||||
return cancel_event.is_set()
|
||||
|
||||
logger = LoggerInterface(services=services, context_data=context_data)
|
||||
images = ImagesInterface(services=services, context_data=context_data)
|
||||
tensors = TensorsInterface(services=services, context_data=context_data)
|
||||
models = ModelsInterface(services=services, context_data=context_data)
|
||||
config = ConfigInterface(services=services, context_data=context_data)
|
||||
util = UtilInterface(services=services, context_data=context_data, is_canceled=is_canceled)
|
||||
util = UtilInterface(services=services, context_data=context_data, cancel_event=cancel_event)
|
||||
conditioning = ConditioningInterface(services=services, context_data=context_data)
|
||||
boards = BoardsInterface(services=services, context_data=context_data)
|
||||
|
||||
@ -467,7 +465,6 @@ def build_invocation_context(
|
||||
conditioning=conditioning,
|
||||
services=services,
|
||||
boards=boards,
|
||||
is_canceled=is_canceled,
|
||||
)
|
||||
|
||||
return ctx
|
||||
|
Loading…
Reference in New Issue
Block a user