mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor(nodes): move is_canceled to context.util
This commit is contained in:
parent
18adcc1dd2
commit
fdac0c3c9b
@ -1,7 +1,7 @@
|
|||||||
import threading
|
import threading
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Callable, Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
from PIL.Image import Image
|
from PIL.Image import Image
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
@ -364,10 +364,14 @@ class ConfigInterface(InvocationContextInterface):
|
|||||||
|
|
||||||
class UtilInterface(InvocationContextInterface):
|
class UtilInterface(InvocationContextInterface):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, services: InvocationServices, context_data: InvocationContextData, is_canceled: Callable[[], bool]
|
self, services: InvocationServices, context_data: InvocationContextData, cancel_event: threading.Event
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(services, context_data)
|
super().__init__(services, context_data)
|
||||||
self._is_canceled = is_canceled
|
self._cancel_event = cancel_event
|
||||||
|
|
||||||
|
def is_canceled(self) -> bool:
|
||||||
|
"""Checks if the current invocation has been canceled."""
|
||||||
|
return self._cancel_event.is_set()
|
||||||
|
|
||||||
def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None:
|
def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None:
|
||||||
"""
|
"""
|
||||||
@ -385,7 +389,7 @@ class UtilInterface(InvocationContextInterface):
|
|||||||
intermediate_state=intermediate_state,
|
intermediate_state=intermediate_state,
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
events=self._services.events,
|
events=self._services.events,
|
||||||
is_canceled=self._is_canceled,
|
is_canceled=self.is_canceled,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -406,7 +410,6 @@ class InvocationContext:
|
|||||||
boards: BoardsInterface,
|
boards: BoardsInterface,
|
||||||
context_data: InvocationContextData,
|
context_data: InvocationContextData,
|
||||||
services: InvocationServices,
|
services: InvocationServices,
|
||||||
is_canceled: Callable[[], bool],
|
|
||||||
) -> None:
|
) -> None:
|
||||||
self.images = images
|
self.images = images
|
||||||
"""Provides methods to save, get and update images and their metadata."""
|
"""Provides methods to save, get and update images and their metadata."""
|
||||||
@ -424,8 +427,6 @@ class InvocationContext:
|
|||||||
"""Provides utility methods."""
|
"""Provides utility methods."""
|
||||||
self.boards = boards
|
self.boards = boards
|
||||||
"""Provides methods to interact with boards."""
|
"""Provides methods to interact with boards."""
|
||||||
self.is_canceled = is_canceled
|
|
||||||
"""Checks if the current invocation has been canceled."""
|
|
||||||
self._data = context_data
|
self._data = context_data
|
||||||
"""Provides data about the current queue item and invocation. This is an internal API and may change without warning."""
|
"""Provides data about the current queue item and invocation. This is an internal API and may change without warning."""
|
||||||
self._services = services
|
self._services = services
|
||||||
@ -444,15 +445,12 @@ def build_invocation_context(
|
|||||||
:param invocation_context_data: The invocation context data.
|
:param invocation_context_data: The invocation context data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def is_canceled() -> bool:
|
|
||||||
return cancel_event.is_set()
|
|
||||||
|
|
||||||
logger = LoggerInterface(services=services, context_data=context_data)
|
logger = LoggerInterface(services=services, context_data=context_data)
|
||||||
images = ImagesInterface(services=services, context_data=context_data)
|
images = ImagesInterface(services=services, context_data=context_data)
|
||||||
tensors = TensorsInterface(services=services, context_data=context_data)
|
tensors = TensorsInterface(services=services, context_data=context_data)
|
||||||
models = ModelsInterface(services=services, context_data=context_data)
|
models = ModelsInterface(services=services, context_data=context_data)
|
||||||
config = ConfigInterface(services=services, context_data=context_data)
|
config = ConfigInterface(services=services, context_data=context_data)
|
||||||
util = UtilInterface(services=services, context_data=context_data, is_canceled=is_canceled)
|
util = UtilInterface(services=services, context_data=context_data, cancel_event=cancel_event)
|
||||||
conditioning = ConditioningInterface(services=services, context_data=context_data)
|
conditioning = ConditioningInterface(services=services, context_data=context_data)
|
||||||
boards = BoardsInterface(services=services, context_data=context_data)
|
boards = BoardsInterface(services=services, context_data=context_data)
|
||||||
|
|
||||||
@ -467,7 +465,6 @@ def build_invocation_context(
|
|||||||
conditioning=conditioning,
|
conditioning=conditioning,
|
||||||
services=services,
|
services=services,
|
||||||
boards=boards,
|
boards=boards,
|
||||||
is_canceled=is_canceled,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return ctx
|
return ctx
|
||||||
|
Loading…
Reference in New Issue
Block a user