feat(queue): session queue error handling

- Add handling for new error columns `error_type`, `error_message`, `error_traceback`.
- Update queue item model to include the new data. The `error_traceback` field accepts `error` as a validation alias for backwards compatibility.
- Add `fail_queue_item` method. Failing a queue item was previously handled by `cancel_queue_item` through its optional `error` argument. Splitting this functionality out makes failing a queue item more explicit, and avoids adding multiple optional error args to `cancel_queue_item`.
This commit is contained in:
psychedelicious 2024-05-23 15:15:53 +10:00
parent 887b73aece
commit 25954ea750
3 changed files with 70 additions and 12 deletions

View File

@ -74,10 +74,15 @@ class SessionQueueBase(ABC):
pass pass
@abstractmethod @abstractmethod
def cancel_queue_item(self, item_id: int, error: Optional[str] = None) -> SessionQueueItem: def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
"""Cancels a session queue item""" """Cancels a session queue item"""
pass pass
@abstractmethod
def fail_queue_item(self, item_id: int, error_type: str, error_message: str, error_traceback: str) -> SessionQueueItem:
"""Fails a session queue item"""
pass
@abstractmethod @abstractmethod
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult: def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
"""Cancels all queue items with matching batch IDs""" """Cancels all queue items with matching batch IDs"""

View File

@ -3,7 +3,16 @@ import json
from itertools import chain, product from itertools import chain, product
from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast
from pydantic import BaseModel, ConfigDict, Field, StrictStr, TypeAdapter, field_validator, model_validator from pydantic import (
AliasChoices,
BaseModel,
ConfigDict,
Field,
StrictStr,
TypeAdapter,
field_validator,
model_validator,
)
from pydantic_core import to_jsonable_python from pydantic_core import to_jsonable_python
from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.invocations.baseinvocation import BaseInvocation
@ -189,7 +198,13 @@ class SessionQueueItemWithoutGraph(BaseModel):
session_id: str = Field( session_id: str = Field(
description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed." description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
) )
error: Optional[str] = Field(default=None, description="The error message if this queue item errored") error_type: Optional[str] = Field(default=None, description="The error type if this queue item errored")
error_message: Optional[str] = Field(default=None, description="The error message if this queue item errored")
error_traceback: Optional[str] = Field(
default=None,
description="The error traceback if this queue item errored",
validation_alias=AliasChoices("error_traceback", "error"),
)
created_at: Union[datetime.datetime, str] = Field(description="When this queue item was created") created_at: Union[datetime.datetime, str] = Field(description="When this queue item was created")
updated_at: Union[datetime.datetime, str] = Field(description="When this queue item was updated") updated_at: Union[datetime.datetime, str] = Field(description="When this queue item was updated")
started_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was started") started_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was started")

View File

@ -82,10 +82,18 @@ class SqliteSessionQueue(SessionQueueBase):
async def _handle_error_event(self, event: FastAPIEvent) -> None: async def _handle_error_event(self, event: FastAPIEvent) -> None:
try: try:
item_id = event[1]["data"]["queue_item_id"] item_id = event[1]["data"]["queue_item_id"]
error = event[1]["data"]["error"] error_type = event[1]["data"]["error_type"]
error_message = event[1]["data"]["error_message"]
error_traceback = event[1]["data"]["error_traceback"]
queue_item = self.get_queue_item(item_id) queue_item = self.get_queue_item(item_id)
# always set to failed if have an error, even if previously the item was marked completed or canceled # always set to failed if have an error, even if previously the item was marked completed or canceled
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="failed", error=error) queue_item = self._set_queue_item_status(
item_id=queue_item.item_id,
status="failed",
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
except SessionQueueItemNotFoundError: except SessionQueueItemNotFoundError:
return return
@ -272,17 +280,22 @@ class SqliteSessionQueue(SessionQueueBase):
return SessionQueueItem.queue_item_from_dict(dict(result)) return SessionQueueItem.queue_item_from_dict(dict(result))
def _set_queue_item_status( def _set_queue_item_status(
self, item_id: int, status: QUEUE_ITEM_STATUS, error: Optional[str] = None self,
item_id: int,
status: QUEUE_ITEM_STATUS,
error_type: Optional[str] = None,
error_message: Optional[str] = None,
error_traceback: Optional[str] = None,
) -> SessionQueueItem: ) -> SessionQueueItem:
try: try:
self.__lock.acquire() self.__lock.acquire()
self.__cursor.execute( self.__cursor.execute(
"""--sql """--sql
UPDATE session_queue UPDATE session_queue
SET status = ?, error = ? SET status = ?, error_type = ?, error_message = ?, error_traceback = ?
WHERE item_id = ? WHERE item_id = ?
""", """,
(status, error, item_id), (status, error_type, error_message, error_traceback, item_id),
) )
self.__conn.commit() self.__conn.commit()
except Exception: except Exception:
@ -425,11 +438,34 @@ class SqliteSessionQueue(SessionQueueBase):
self.__lock.release() self.__lock.release()
return PruneResult(deleted=count) return PruneResult(deleted=count)
def cancel_queue_item(self, item_id: int, error: Optional[str] = None) -> SessionQueueItem: def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
queue_item = self.get_queue_item(item_id) queue_item = self.get_queue_item(item_id)
if queue_item.status not in ["canceled", "failed", "completed"]: if queue_item.status not in ["canceled", "failed", "completed"]:
status = "failed" if error is not None else "canceled" queue_item = self._set_queue_item_status(item_id=item_id, status="canceled")
queue_item = self._set_queue_item_status(item_id=item_id, status=status, error=error) # type: ignore [arg-type] # mypy seems to not narrow the Literals here self.__invoker.services.events.emit_session_canceled(
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
queue_batch_id=queue_item.batch_id,
graph_execution_state_id=queue_item.session_id,
)
return queue_item
def fail_queue_item(
self,
item_id: int,
error_type: str,
error_message: str,
error_traceback: str,
) -> SessionQueueItem:
queue_item = self.get_queue_item(item_id)
if queue_item.status not in ["canceled", "failed", "completed"]:
queue_item = self._set_queue_item_status(
item_id=item_id,
status="failed",
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
self.__invoker.services.events.emit_session_canceled( self.__invoker.services.events.emit_session_canceled(
queue_item_id=queue_item.item_id, queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id, queue_id=queue_item.queue_id,
@ -602,7 +638,9 @@ class SqliteSessionQueue(SessionQueueBase):
status, status,
priority, priority,
field_values, field_values,
error, error_type,
error_message,
error_traceback,
created_at, created_at,
updated_at, updated_at,
completed_at, completed_at,