From d6696a7b9793a5c9dd155b9baadfc8fa334d0b5f Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 23 May 2024 15:15:53 +1000 Subject: [PATCH] feat(queue): session queue error handling - Add handling for new error columns `error_type`, `error_message`, `error_traceback`. - Update queue item model to include the new data. The `error_traceback` field has a validation alias of `error` for backwards compatibility. - Add `fail_queue_item` method. This was previously handled by `cancel_queue_item`. Splitting this functionality makes failing a queue item a bit more explicit. We also don't need to handle multiple optional error args. --- .../session_queue/session_queue_base.py | 7 ++- .../session_queue/session_queue_common.py | 19 ++++++- .../session_queue/session_queue_sqlite.py | 56 ++++++++++++++++--- 3 files changed, 70 insertions(+), 12 deletions(-) diff --git a/invokeai/app/services/session_queue/session_queue_base.py b/invokeai/app/services/session_queue/session_queue_base.py index f46463f528..8b21998f19 100644 --- a/invokeai/app/services/session_queue/session_queue_base.py +++ b/invokeai/app/services/session_queue/session_queue_base.py @@ -74,10 +74,15 @@ class SessionQueueBase(ABC): pass @abstractmethod - def cancel_queue_item(self, item_id: int, error: Optional[str] = None) -> SessionQueueItem: + def cancel_queue_item(self, item_id: int) -> SessionQueueItem: """Cancels a session queue item""" pass + @abstractmethod + def fail_queue_item(self, item_id: int, error_type: str, error_message: str, error_traceback: str) -> SessionQueueItem: + """Fails a session queue item""" + pass + @abstractmethod def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult: """Cancels all queue items with matching batch IDs""" diff --git a/invokeai/app/services/session_queue/session_queue_common.py b/invokeai/app/services/session_queue/session_queue_common.py index 94db6999c2..7f4601eba7 100644 --- 
a/invokeai/app/services/session_queue/session_queue_common.py +++ b/invokeai/app/services/session_queue/session_queue_common.py @@ -3,7 +3,16 @@ import json from itertools import chain, product from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast -from pydantic import BaseModel, ConfigDict, Field, StrictStr, TypeAdapter, field_validator, model_validator +from pydantic import ( + AliasChoices, + BaseModel, + ConfigDict, + Field, + StrictStr, + TypeAdapter, + field_validator, + model_validator, +) from pydantic_core import to_jsonable_python from invokeai.app.invocations.baseinvocation import BaseInvocation @@ -189,7 +198,13 @@ class SessionQueueItemWithoutGraph(BaseModel): session_id: str = Field( description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed." ) - error: Optional[str] = Field(default=None, description="The error message if this queue item errored") + error_type: Optional[str] = Field(default=None, description="The error type if this queue item errored") + error_message: Optional[str] = Field(default=None, description="The error message if this queue item errored") + error_traceback: Optional[str] = Field( + default=None, + description="The error traceback if this queue item errored", + validation_alias=AliasChoices("error_traceback", "error"), + ) created_at: Union[datetime.datetime, str] = Field(description="When this queue item was created") updated_at: Union[datetime.datetime, str] = Field(description="When this queue item was updated") started_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was started") diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index 87c22c496f..dfd00a7809 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ 
b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -82,10 +82,18 @@ class SqliteSessionQueue(SessionQueueBase): async def _handle_error_event(self, event: FastAPIEvent) -> None: try: item_id = event[1]["data"]["queue_item_id"] - error = event[1]["data"]["error"] + error_type = event[1]["data"]["error_type"] + error_message = event[1]["data"]["error_message"] + error_traceback = event[1]["data"]["error_traceback"] queue_item = self.get_queue_item(item_id) # always set to failed if have an error, even if previously the item was marked completed or canceled - queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="failed", error=error) + queue_item = self._set_queue_item_status( + item_id=queue_item.item_id, + status="failed", + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, + ) except SessionQueueItemNotFoundError: return @@ -272,17 +280,22 @@ class SqliteSessionQueue(SessionQueueBase): return SessionQueueItem.queue_item_from_dict(dict(result)) def _set_queue_item_status( - self, item_id: int, status: QUEUE_ITEM_STATUS, error: Optional[str] = None + self, + item_id: int, + status: QUEUE_ITEM_STATUS, + error_type: Optional[str] = None, + error_message: Optional[str] = None, + error_traceback: Optional[str] = None, ) -> SessionQueueItem: try: self.__lock.acquire() self.__cursor.execute( """--sql UPDATE session_queue - SET status = ?, error = ? + SET status = ?, error_type = ?, error_message = ?, error_traceback = ? WHERE item_id = ? 
""", - (status, error, item_id), + (status, error_type, error_message, error_traceback, item_id), ) self.__conn.commit() except Exception: @@ -425,11 +438,34 @@ class SqliteSessionQueue(SessionQueueBase): self.__lock.release() return PruneResult(deleted=count) - def cancel_queue_item(self, item_id: int, error: Optional[str] = None) -> SessionQueueItem: + def cancel_queue_item(self, item_id: int) -> SessionQueueItem: queue_item = self.get_queue_item(item_id) if queue_item.status not in ["canceled", "failed", "completed"]: - status = "failed" if error is not None else "canceled" - queue_item = self._set_queue_item_status(item_id=item_id, status=status, error=error) # type: ignore [arg-type] # mypy seems to not narrow the Literals here + queue_item = self._set_queue_item_status(item_id=item_id, status="canceled") + self.__invoker.services.events.emit_session_canceled( + queue_item_id=queue_item.item_id, + queue_id=queue_item.queue_id, + queue_batch_id=queue_item.batch_id, + graph_execution_state_id=queue_item.session_id, + ) + return queue_item + + def fail_queue_item( + self, + item_id: int, + error_type: str, + error_message: str, + error_traceback: str, + ) -> SessionQueueItem: + queue_item = self.get_queue_item(item_id) + if queue_item.status not in ["canceled", "failed", "completed"]: + queue_item = self._set_queue_item_status( + item_id=item_id, + status="failed", + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, + ) self.__invoker.services.events.emit_session_canceled( queue_item_id=queue_item.item_id, queue_id=queue_item.queue_id, @@ -602,7 +638,9 @@ class SqliteSessionQueue(SessionQueueBase): status, priority, field_values, - error, + error_type, + error_message, + error_traceback, created_at, updated_at, completed_at,