feat(processor): update enriched errors & fail_queue_item()

This commit is contained in:
psychedelicious 2024-05-23 15:20:22 +10:00
parent 6a34176376
commit db0ef8d316
2 changed files with 83 additions and 53 deletions

View File

@ -1,6 +1,5 @@
from abc import ABC, abstractmethod
from threading import Event
from types import TracebackType
from typing import Optional, Protocol
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
@ -71,9 +70,9 @@ class OnNodeError(Protocol):
self,
invocation: BaseInvocation,
queue_item: SessionQueueItem,
exc_type: type,
exc_value: BaseException,
exc_traceback: TracebackType,
error_type: str,
error_message: str,
error_traceback: str,
) -> bool: ...
@ -88,8 +87,8 @@ class OnAfterRunSession(Protocol):
class OnNonFatalProcessorError(Protocol):
def __call__(
self,
exc_type: type,
exc_value: BaseException,
exc_traceback: TracebackType,
queue_item: Optional[SessionQueueItem] = None,
queue_item: Optional[SessionQueueItem],
error_type: str,
error_message: str,
error_traceback: str,
) -> bool: ...

View File

@ -2,7 +2,6 @@ import traceback
from contextlib import suppress
from threading import BoundedSemaphore, Thread
from threading import Event as ThreadEvent
from types import TracebackType
from typing import Optional
from fastapi_events.handlers.local import local_handler
@ -30,12 +29,6 @@ from .session_processor_base import InvocationServices, SessionProcessorBase, Se
from .session_processor_common import SessionProcessorStatus
def get_stacktrace(exc_type: type, exc_value: BaseException, exc_traceback: TracebackType) -> str:
"""Formats a stacktrace as a string"""
return "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))
class DefaultSessionRunner(SessionRunnerBase):
"""Processes a single session's invocations"""
@ -71,10 +64,16 @@ class DefaultSessionRunner(SessionRunnerBase):
invocation = queue_item.session.next()
# Anything other than a `NodeInputError` is handled as a processor error
except NodeInputError as e:
# Must extract the exception traceback here to not lose its stacktrace when we change scope
traceback = e.__traceback__
assert traceback is not None
self._on_node_error(e.node, queue_item, type(e), e, traceback)
error_type = e.__class__.__name__
error_message = str(e)
error_traceback = traceback.format_exc()
self._on_node_error(
invocation=e.node,
queue_item=queue_item,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
break
if invocation is None or self._cancel_event.is_set():
@ -126,10 +125,16 @@ class DefaultSessionRunner(SessionRunnerBase):
# loop go to its next iteration, and the cancel event will be handled correctly.
pass
except Exception as e:
# Must extract the exception traceback here to not lose its stacktrace when we change scope
exc_traceback = e.__traceback__
assert exc_traceback is not None
self._on_node_error(invocation, queue_item, type(e), e, exc_traceback)
error_type = e.__class__.__name__
error_message = str(e)
error_traceback = traceback.format_exc()
self._on_node_error(
invocation=invocation,
queue_item=queue_item,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
# If profiling is enabled, start the profiler
@ -166,7 +171,7 @@ class DefaultSessionRunner(SessionRunnerBase):
self._services.performance_statistics.reset_stats()
for callback in self._on_after_run_session_callbacks:
callback(queue_item)
callback(queue_item=queue_item)
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Run before a node is executed"""
@ -181,7 +186,7 @@ class DefaultSessionRunner(SessionRunnerBase):
)
for callback in self._on_before_run_node_callbacks:
callback(invocation, queue_item)
callback(invocation=invocation, queue_item=queue_item)
def _on_after_run_node(
self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput
@ -199,23 +204,23 @@ class DefaultSessionRunner(SessionRunnerBase):
)
for callback in self._on_after_run_node_callbacks:
callback(invocation, queue_item, output)
callback(invocation=invocation, queue_item=queue_item, output=output)
def _on_node_error(
self,
invocation: BaseInvocation,
queue_item: SessionQueueItem,
exc_type: type,
exc_value: BaseException,
exc_traceback: TracebackType,
error_type: str,
error_message: str,
error_traceback: str,
):
stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback)
queue_item.session.set_node_error(invocation.id, stacktrace)
# Node errors do not get the full traceback. Only the queue item gets the full traceback.
node_error = f"{error_type}: {error_message}"
queue_item.session.set_node_error(invocation.id, node_error)
self._services.logger.error(
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}): {exc_type.__name__}"
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}): {error_message}"
)
self._services.logger.error(stacktrace)
self._services.logger.error(error_traceback)
# Send error event
self._services.events.emit_invocation_error(
@ -225,14 +230,21 @@ class DefaultSessionRunner(SessionRunnerBase):
graph_execution_state_id=queue_item.session.id,
node=invocation.model_dump(),
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
error_type=exc_type.__name__,
error=stacktrace,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
user_id=getattr(queue_item, "user_id", None),
project_id=getattr(queue_item, "project_id", None),
)
for callback in self._on_node_error_callbacks:
callback(invocation, queue_item, exc_type, exc_value, exc_traceback)
callback(
invocation=invocation,
queue_item=queue_item,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
class DefaultSessionProcessor(SessionProcessorBase):
@ -374,16 +386,25 @@ class DefaultSessionProcessor(SessionProcessorBase):
self.session_runner.run(queue_item=self._queue_item)
except Exception as e:
# Must extract the exception traceback here to not lose its stacktrace when we change scope
exc_traceback = e.__traceback__
assert exc_traceback is not None
self._on_non_fatal_processor_error(self._queue_item, type(e), e, exc_traceback)
# Immediately poll for next queue item
error_type = e.__class__.__name__
error_message = str(e)
error_traceback = traceback.format_exc()
self._on_non_fatal_processor_error(
queue_item=self._queue_item,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
# Wait for next polling interval or event to try again
poll_now_event.wait(self._polling_interval)
continue
except Exception:
except Exception as e:
# Fatal error in processor, log and pass - we're done here
self._invoker.services.logger.error(f"Fatal Error in session processor:\n{traceback.format_exc()}")
error_type = e.__class__.__name__
error_message = str(e)
error_traceback = traceback.format_exc()
self._invoker.services.logger.error(f"Fatal Error in session processor {error_type}: {error_message}")
self._invoker.services.logger.error(error_traceback)
pass
finally:
stop_event.clear()
@ -394,19 +415,29 @@ class DefaultSessionProcessor(SessionProcessorBase):
def _on_non_fatal_processor_error(
self,
queue_item: Optional[SessionQueueItem],
exc_type: type,
exc_value: BaseException,
exc_traceback: TracebackType,
error_type: str,
error_message: str,
error_traceback: str,
) -> None:
stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback)
# Non-fatal error in processor
self._invoker.services.logger.error(f"Non-fatal error in session processor: {exc_type.__name__}")
self._invoker.services.logger.error(stacktrace)
self._invoker.services.logger.error(f"Non-fatal error in session processor {error_type}: {error_message}")
self._invoker.services.logger.error(error_traceback)
if queue_item is not None:
# Update the queue item with the completed session
self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
# And cancel the queue item with an error
self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace)
# Fail the queue item
self._invoker.services.session_queue.fail_queue_item(
item_id=queue_item.item_id,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
for callback in self._on_non_fatal_processor_error_callbacks:
callback(exc_type, exc_value, exc_traceback, queue_item)
callback(
queue_item=queue_item,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)