From bd7e515290b9ccf9b44398933e486df27e1c60a0 Mon Sep 17 00:00:00 2001 From: Kyle Schouviller Date: Thu, 16 Mar 2023 20:05:36 -0700 Subject: [PATCH] [nodes] Add cancelation to the API --- invokeai/app/api/routers/sessions.py | 15 ++++++++++ invokeai/app/services/invocation_queue.py | 34 ++++++++++++++++++++++- invokeai/app/services/invoker.py | 4 +++ invokeai/app/services/processor.py | 12 ++++++++ 4 files changed, 64 insertions(+), 1 deletion(-) diff --git a/invokeai/app/api/routers/sessions.py b/invokeai/app/api/routers/sessions.py index 67e3c840c0..dc8fa03fc4 100644 --- a/invokeai/app/api/routers/sessions.py +++ b/invokeai/app/api/routers/sessions.py @@ -270,3 +270,18 @@ async def invoke_session( ApiDependencies.invoker.invoke(session, invoke_all=all) return Response(status_code=202) + + +@session_router.delete( + "/{session_id}/invoke", + operation_id="cancel_session_invoke", + responses={ + 202: {"description": "The invocation is canceled"} + }, +) +async def cancel_session_invoke( + session_id: str = Path(description="The id of the session to cancel"), +) -> None: + """Invokes a session""" + ApiDependencies.invoker.cancel(session_id) + return Response(status_code=202) diff --git a/invokeai/app/services/invocation_queue.py b/invokeai/app/services/invocation_queue.py index 88a4f8708d..4a42789b12 100644 --- a/invokeai/app/services/invocation_queue.py +++ b/invokeai/app/services/invocation_queue.py @@ -2,6 +2,7 @@ from abc import ABC, abstractmethod from queue import Queue +import time # TODO: make this serializable @@ -10,6 +11,7 @@ class InvocationQueueItem: graph_execution_state_id: str invocation_id: str invoke_all: bool + timestamp: float def __init__( self, @@ -22,6 +24,7 @@ class InvocationQueueItem: self.graph_execution_state_id = graph_execution_state_id self.invocation_id = invocation_id self.invoke_all = invoke_all + self.timestamp = time.time() class InvocationQueueABC(ABC): @@ -35,15 +38,44 @@ class InvocationQueueABC(ABC): def put(self, item: InvocationQueueItem | None) -> None: pass + @abstractmethod + def cancel(self, graph_execution_state_id: str) -> None: + pass + + @abstractmethod + def is_canceled(self, graph_execution_state_id: str) -> bool: + pass + class MemoryInvocationQueue(InvocationQueueABC): __queue: Queue + __cancellations: dict[str, float] def __init__(self): self.__queue = Queue() + self.__cancellations = dict() def get(self) -> InvocationQueueItem: - return self.__queue.get() + item = self.__queue.get() + + while isinstance(item, InvocationQueueItem) \ + and item.graph_execution_state_id in self.__cancellations \ + and self.__cancellations[item.graph_execution_state_id] > item.timestamp: + item = self.__queue.get() + + # Clear old items + for graph_execution_state_id in list(self.__cancellations.keys()): + if self.__cancellations[graph_execution_state_id] < item.timestamp: + del self.__cancellations[graph_execution_state_id] + + return item def put(self, item: InvocationQueueItem | None) -> None: self.__queue.put(item) + + def cancel(self, graph_execution_state_id: str) -> None: + if graph_execution_state_id not in self.__cancellations: + self.__cancellations[graph_execution_state_id] = time.time() + + def is_canceled(self, graph_execution_state_id: str) -> bool: + return graph_execution_state_id in self.__cancellations diff --git a/invokeai/app/services/invoker.py b/invokeai/app/services/invoker.py index f234cd827b..594477ed0f 100644 --- a/invokeai/app/services/invoker.py +++ b/invokeai/app/services/invoker.py @@ -50,6 +50,10 @@ class Invoker: new_state = GraphExecutionState(graph=Graph() if graph is None else graph) self.services.graph_execution_manager.set(new_state) return new_state + + def cancel(self, graph_execution_state_id: str) -> None: + """Cancels the given execution state""" + self.services.queue.cancel(graph_execution_state_id) def __start_service(self, service) -> None: # Call start() method on any services that have it diff --git a/invokeai/app/services/processor.py b/invokeai/app/services/processor.py index 5baa64503c..e86da265f1 100644 --- a/invokeai/app/services/processor.py +++ b/invokeai/app/services/processor.py @@ -58,6 +58,12 @@ class DefaultInvocationProcessor(InvocationProcessorABC): ) ) + # Check queue to see if this is canceled, and skip if so + if self.__invoker.services.queue.is_canceled( + graph_execution_state.id + ): + continue + # Save outputs and history graph_execution_state.complete(invocation.id, outputs) @@ -95,6 +101,12 @@ class DefaultInvocationProcessor(InvocationProcessorABC): ) pass + + # Check queue to see if this is canceled, and skip if so + if self.__invoker.services.queue.is_canceled( + graph_execution_state.id + ): + continue # Queue any further commands if invoking all is_complete = graph_execution_state.is_complete()