mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
[nodes] Add cancelation to the API
This commit is contained in:
parent
27a113d872
commit
510ae34bff
@ -270,3 +270,18 @@ async def invoke_session(
|
||||
|
||||
ApiDependencies.invoker.invoke(session, invoke_all=all)
|
||||
return Response(status_code=202)
|
||||
|
||||
|
||||
@session_router.delete(
|
||||
"/{session_id}/invoke",
|
||||
operation_id="cancel_session_invoke",
|
||||
responses={
|
||||
202: {"description": "The invocation is canceled"}
|
||||
},
|
||||
)
|
||||
async def cancel_session_invoke(
|
||||
session_id: str = Path(description="The id of the session to cancel"),
|
||||
) -> None:
|
||||
"""Invokes a session"""
|
||||
ApiDependencies.invoker.cancel(session_id)
|
||||
return Response(status_code=202)
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from queue import Queue
|
||||
import time
|
||||
|
||||
|
||||
# TODO: make this serializable
|
||||
@ -10,6 +11,7 @@ class InvocationQueueItem:
|
||||
graph_execution_state_id: str
|
||||
invocation_id: str
|
||||
invoke_all: bool
|
||||
timestamp: float
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -22,6 +24,7 @@ class InvocationQueueItem:
|
||||
self.graph_execution_state_id = graph_execution_state_id
|
||||
self.invocation_id = invocation_id
|
||||
self.invoke_all = invoke_all
|
||||
self.timestamp = time.time()
|
||||
|
||||
|
||||
class InvocationQueueABC(ABC):
|
||||
@ -35,15 +38,44 @@ class InvocationQueueABC(ABC):
|
||||
def put(self, item: InvocationQueueItem | None) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel(self, graph_execution_state_id: str) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_canceled(self, graph_execution_state_id: str) -> bool:
|
||||
pass
|
||||
|
||||
|
||||
class MemoryInvocationQueue(InvocationQueueABC):
|
||||
__queue: Queue
|
||||
__cancellations: dict[str, float]
|
||||
|
||||
def __init__(self):
|
||||
self.__queue = Queue()
|
||||
self.__cancellations = dict()
|
||||
|
||||
def get(self) -> InvocationQueueItem:
|
||||
return self.__queue.get()
|
||||
item = self.__queue.get()
|
||||
|
||||
while isinstance(item, InvocationQueueItem) \
|
||||
and item.graph_execution_state_id in self.__cancellations \
|
||||
and self.__cancellations[item.graph_execution_state_id] > item.timestamp:
|
||||
item = self.__queue.get()
|
||||
|
||||
# Clear old items
|
||||
for graph_execution_state_id in list(self.__cancellations.keys()):
|
||||
if self.__cancellations[graph_execution_state_id] < item.timestamp:
|
||||
del self.__cancellations[graph_execution_state_id]
|
||||
|
||||
return item
|
||||
|
||||
def put(self, item: InvocationQueueItem | None) -> None:
|
||||
self.__queue.put(item)
|
||||
|
||||
def cancel(self, graph_execution_state_id: str) -> None:
|
||||
if graph_execution_state_id not in self.__cancellations:
|
||||
self.__cancellations[graph_execution_state_id] = time.time()
|
||||
|
||||
def is_canceled(self, graph_execution_state_id: str) -> bool:
|
||||
return graph_execution_state_id in self.__cancellations
|
||||
|
@ -50,6 +50,10 @@ class Invoker:
|
||||
new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
|
||||
self.services.graph_execution_manager.set(new_state)
|
||||
return new_state
|
||||
|
||||
def cancel(self, graph_execution_state_id: str) -> None:
|
||||
"""Cancels the given execution state"""
|
||||
self.services.queue.cancel(graph_execution_state_id)
|
||||
|
||||
def __start_service(self, service) -> None:
|
||||
# Call start() method on any services that have it
|
||||
|
@ -58,6 +58,12 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
)
|
||||
)
|
||||
|
||||
# Check queue to see if this is canceled, and skip if so
|
||||
if self.__invoker.services.queue.is_canceled(
|
||||
graph_execution_state.id
|
||||
):
|
||||
continue
|
||||
|
||||
# Save outputs and history
|
||||
graph_execution_state.complete(invocation.id, outputs)
|
||||
|
||||
@ -95,6 +101,12 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
)
|
||||
|
||||
pass
|
||||
|
||||
# Check queue to see if this is canceled, and skip if so
|
||||
if self.__invoker.services.queue.is_canceled(
|
||||
graph_execution_state.id
|
||||
):
|
||||
continue
|
||||
|
||||
# Queue any further commands if invoking all
|
||||
is_complete = graph_execution_state.is_complete()
|
||||
|
Loading…
Reference in New Issue
Block a user