diff --git a/invokeai/app/api/routers/session_queue.py b/invokeai/app/api/routers/session_queue.py index 40f1f2213b..7161e54a41 100644 --- a/invokeai/app/api/routers/session_queue.py +++ b/invokeai/app/api/routers/session_queue.py @@ -203,6 +203,7 @@ async def get_batch_status( responses={ 200: {"model": SessionQueueItem}, }, + response_model_exclude_none=True, ) async def get_queue_item( queue_id: str = Path(description="The queue id to perform this operation on"), diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 894996b1e6..2a0ebc3168 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -245,6 +245,9 @@ class DefaultSessionProcessor(SessionProcessorBase): # The session is complete if the all invocations are complete or there was an error if self._queue_item.session.is_complete() or cancel_event.is_set(): # Send complete event + self._invoker.services.session_queue.set_queue_item_session( + self._queue_item.item_id, self._queue_item.session + ) self._invoker.services.events.emit_graph_execution_complete( queue_batch_id=self._queue_item.batch_id, queue_item_id=self._queue_item.item_id, @@ -281,6 +284,9 @@ class DefaultSessionProcessor(SessionProcessorBase): ) # Cancel the queue item if self._queue_item is not None: + self._invoker.services.session_queue.set_queue_item_session( + self._queue_item.item_id, self._queue_item.session + ) self._invoker.services.session_queue.cancel_queue_item( self._queue_item.item_id, error=traceback.format_exc() ) diff --git a/invokeai/app/services/session_queue/session_queue_base.py b/invokeai/app/services/session_queue/session_queue_base.py index e0b6e4f528..f46463f528 100644 --- a/invokeai/app/services/session_queue/session_queue_base.py +++ b/invokeai/app/services/session_queue/session_queue_base.py @@ -16,6 +16,7 @@ from invokeai.app.services.session_queue.session_queue_common import ( SessionQueueItemDTO, SessionQueueStatus, ) +from invokeai.app.services.shared.graph import GraphExecutionState from invokeai.app.services.shared.pagination import CursorPaginatedResults @@ -103,3 +104,8 @@ class SessionQueueBase(ABC): def get_queue_item(self, item_id: int) -> SessionQueueItem: """Gets a session queue item by ID""" pass + + @abstractmethod + def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem: + """Sets the session for a session queue item. Use this to update the session state.""" + pass diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index ffcd7c40ca..87c22c496f 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -27,6 +27,7 @@ from invokeai.app.services.session_queue.session_queue_common import ( calc_session_count, prepare_values_to_insert, ) +from invokeai.app.services.shared.graph import GraphExecutionState from invokeai.app.services.shared.pagination import CursorPaginatedResults from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase @@ -562,6 +563,29 @@ class SqliteSessionQueue(SessionQueueBase): raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}") return SessionQueueItem.queue_item_from_dict(dict(result)) + def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem: + try: + # Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause validation errors + # when the graph is loaded. Graph execution occurs purely in memory - the session saved here is not referenced + # during execution. + session_json = session.model_dump_json(warnings=False, exclude_none=True) + self.__lock.acquire() + self.__cursor.execute( + """--sql + UPDATE session_queue + SET session = ? + WHERE item_id = ? + """, + (session_json, item_id), + ) + self.__conn.commit() + except Exception: + self.__conn.rollback() + raise + finally: + self.__lock.release() + return self.get_queue_item(item_id) + def list_queue_items( self, queue_id: str,