mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(app): update queue item's session on session completion
The session is never updated in the queue after it is first enqueued. As a result, the queue detail view in the frontend never never updates and the session itself doesn't show outputs, execution graph, etc. We need a new method on the queue service to update a queue item's session, then call it before updating the queue item's status. Queue item status may be updated via a session-type event _or_ queue-type event. Adding the updated session to all these events is a hairy - simpler to just update the session before we do anything that could trigger a queue item status change event: - Before calling `emit_session_complete` in the processor (handles session error, completed and cancel events and the corresponding queue events) - Before calling `cancel_queue_item` in the processor (handles another way queue items can be canceled, outside the session execution loop) When serializing the session, both in the new service method and the `get_queue_item` endpoint, we need to use `exclude_none=True` to prevent unexpected validation errors.
This commit is contained in:
parent
242b2a0b59
commit
12ce095bb2
@ -203,6 +203,7 @@ async def get_batch_status(
|
|||||||
responses={
|
responses={
|
||||||
200: {"model": SessionQueueItem},
|
200: {"model": SessionQueueItem},
|
||||||
},
|
},
|
||||||
|
response_model_exclude_none=True,
|
||||||
)
|
)
|
||||||
async def get_queue_item(
|
async def get_queue_item(
|
||||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||||
|
@ -222,6 +222,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
|
|
||||||
# The session is complete if the all invocations are complete or there was an error
|
# The session is complete if the all invocations are complete or there was an error
|
||||||
if self._queue_item.session.is_complete() or cancel_event.is_set():
|
if self._queue_item.session.is_complete() or cancel_event.is_set():
|
||||||
|
self._invoker.services.session_queue.set_queue_item_session(
|
||||||
|
self._queue_item.item_id, self._queue_item.session
|
||||||
|
)
|
||||||
self._invoker.services.events.emit_session_complete(self._queue_item)
|
self._invoker.services.events.emit_session_complete(self._queue_item)
|
||||||
# If we are profiling, stop the profiler and dump the profile & stats
|
# If we are profiling, stop the profiler and dump the profile & stats
|
||||||
if self._profiler:
|
if self._profiler:
|
||||||
@ -253,6 +256,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
)
|
)
|
||||||
# Cancel the queue item
|
# Cancel the queue item
|
||||||
if self._queue_item is not None:
|
if self._queue_item is not None:
|
||||||
|
self._invoker.services.session_queue.set_queue_item_session(
|
||||||
|
self._queue_item.item_id, self._queue_item.session
|
||||||
|
)
|
||||||
self._invoker.services.session_queue.cancel_queue_item(
|
self._invoker.services.session_queue.cancel_queue_item(
|
||||||
self._queue_item.item_id, error=traceback.format_exc()
|
self._queue_item.item_id, error=traceback.format_exc()
|
||||||
)
|
)
|
||||||
|
@ -16,6 +16,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
|||||||
SessionQueueItemDTO,
|
SessionQueueItemDTO,
|
||||||
SessionQueueStatus,
|
SessionQueueStatus,
|
||||||
)
|
)
|
||||||
|
from invokeai.app.services.shared.graph import GraphExecutionState
|
||||||
from invokeai.app.services.shared.pagination import CursorPaginatedResults
|
from invokeai.app.services.shared.pagination import CursorPaginatedResults
|
||||||
|
|
||||||
|
|
||||||
@ -103,3 +104,8 @@ class SessionQueueBase(ABC):
|
|||||||
def get_queue_item(self, item_id: int) -> SessionQueueItem:
|
def get_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||||
"""Gets a session queue item by ID"""
|
"""Gets a session queue item by ID"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
|
||||||
|
"""Sets the session for a session queue item. Use this to update the session state."""
|
||||||
|
pass
|
||||||
|
@ -30,6 +30,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
|||||||
calc_session_count,
|
calc_session_count,
|
||||||
prepare_values_to_insert,
|
prepare_values_to_insert,
|
||||||
)
|
)
|
||||||
|
from invokeai.app.services.shared.graph import GraphExecutionState
|
||||||
from invokeai.app.services.shared.pagination import CursorPaginatedResults
|
from invokeai.app.services.shared.pagination import CursorPaginatedResults
|
||||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||||
|
|
||||||
@ -530,6 +531,29 @@ class SqliteSessionQueue(SessionQueueBase):
|
|||||||
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
|
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
|
||||||
return SessionQueueItem.queue_item_from_dict(dict(result))
|
return SessionQueueItem.queue_item_from_dict(dict(result))
|
||||||
|
|
||||||
|
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
|
||||||
|
try:
|
||||||
|
# Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause validation errors
|
||||||
|
# when the graph is loaded. Graph execution occurs purely in memory - the session saved here is not referenced
|
||||||
|
# during execution.
|
||||||
|
session_json = session.model_dump_json(warnings=False, exclude_none=True)
|
||||||
|
self.__lock.acquire()
|
||||||
|
self.__cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
UPDATE session_queue
|
||||||
|
SET session = ?
|
||||||
|
WHERE item_id = ?
|
||||||
|
""",
|
||||||
|
(session_json, item_id),
|
||||||
|
)
|
||||||
|
self.__conn.commit()
|
||||||
|
except Exception:
|
||||||
|
self.__conn.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
self.__lock.release()
|
||||||
|
return self.get_queue_item(item_id)
|
||||||
|
|
||||||
def list_queue_items(
|
def list_queue_items(
|
||||||
self,
|
self,
|
||||||
queue_id: str,
|
queue_id: str,
|
||||||
|
Loading…
Reference in New Issue
Block a user