From 93e4c3dbc2a9305836deabde41098085c63e3c9c Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 22 May 2024 08:49:41 +1000 Subject: [PATCH] feat(app): update queue item's session on session completion The session is never updated in the queue after it is first enqueued. As a result, the queue detail view in the frontend never never updates and the session itself doesn't show outputs, execution graph, etc. We need a new method on the queue service to update a queue item's session, then call it before updating the queue item's status. Queue item status may be updated via a session-type event _or_ queue-type event. Adding the updated session to all these events is a hairy - simpler to just update the session before we do anything that could trigger a queue item status change event: - Before calling `emit_session_complete` in the processor (handles session error, completed and cancel events and the corresponding queue events) - Before calling `cancel_queue_item` in the processor (handles another way queue items can be canceled, outside the session execution loop) When serializing the session, both in the new service method and the `get_queue_item` endpoint, we need to use `exclude_none=True` to prevent unexpected validation errors. --- invokeai/app/api/routers/session_queue.py | 1 + .../session_processor_default.py | 6 +++++ .../session_queue/session_queue_base.py | 6 +++++ .../session_queue/session_queue_sqlite.py | 24 +++++++++++++++++++ 4 files changed, 37 insertions(+) diff --git a/invokeai/app/api/routers/session_queue.py b/invokeai/app/api/routers/session_queue.py index 40f1f2213b..7161e54a41 100644 --- a/invokeai/app/api/routers/session_queue.py +++ b/invokeai/app/api/routers/session_queue.py @@ -203,6 +203,7 @@ async def get_batch_status( responses={ 200: {"model": SessionQueueItem}, }, + response_model_exclude_none=True, ) async def get_queue_item( queue_id: str = Path(description="The queue id to perform this operation on"), diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 894996b1e6..2a0ebc3168 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -245,6 +245,9 @@ class DefaultSessionProcessor(SessionProcessorBase): # The session is complete if the all invocations are complete or there was an error if self._queue_item.session.is_complete() or cancel_event.is_set(): # Send complete event + self._invoker.services.session_queue.set_queue_item_session( + self._queue_item.item_id, self._queue_item.session + ) self._invoker.services.events.emit_graph_execution_complete( queue_batch_id=self._queue_item.batch_id, queue_item_id=self._queue_item.item_id, @@ -281,6 +284,9 @@ class DefaultSessionProcessor(SessionProcessorBase): ) # Cancel the queue item if self._queue_item is not None: + self._invoker.services.session_queue.set_queue_item_session( + self._queue_item.item_id, self._queue_item.session + ) self._invoker.services.session_queue.cancel_queue_item( self._queue_item.item_id, error=traceback.format_exc() ) diff --git a/invokeai/app/services/session_queue/session_queue_base.py b/invokeai/app/services/session_queue/session_queue_base.py index e0b6e4f528..f46463f528 100644 --- a/invokeai/app/services/session_queue/session_queue_base.py +++ b/invokeai/app/services/session_queue/session_queue_base.py @@ -16,6 +16,7 @@ from invokeai.app.services.session_queue.session_queue_common import ( SessionQueueItemDTO, SessionQueueStatus, ) +from invokeai.app.services.shared.graph import GraphExecutionState from invokeai.app.services.shared.pagination import CursorPaginatedResults @@ -103,3 +104,8 @@ class SessionQueueBase(ABC): def get_queue_item(self, item_id: int) -> SessionQueueItem: """Gets a session queue item by ID""" pass + + @abstractmethod + def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem: + """Sets the session for a session queue item. Use this to update the session state.""" + pass diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index ffcd7c40ca..87c22c496f 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -27,6 +27,7 @@ from invokeai.app.services.session_queue.session_queue_common import ( calc_session_count, prepare_values_to_insert, ) +from invokeai.app.services.shared.graph import GraphExecutionState from invokeai.app.services.shared.pagination import CursorPaginatedResults from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase @@ -562,6 +563,29 @@ class SqliteSessionQueue(SessionQueueBase): raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}") return SessionQueueItem.queue_item_from_dict(dict(result)) + def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem: + try: + # Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause validation errors + # when the graph is loaded. Graph execution occurs purely in memory - the session saved here is not referenced + # during execution. + session_json = session.model_dump_json(warnings=False, exclude_none=True) + self.__lock.acquire() + self.__cursor.execute( + """--sql + UPDATE session_queue + SET session = ? + WHERE item_id = ? + """, + (session_json, item_id), + ) + self.__conn.commit() + except Exception: + self.__conn.rollback() + raise + finally: + self.__lock.release() + return self.get_queue_item(item_id) + def list_queue_items( self, queue_id: str,