feat(processor): update enriched errors & fail_queue_item()

This commit is contained in:
psychedelicious 2024-05-23 15:20:22 +10:00
parent a8492bd7e4
commit 2dd3a85ade
2 changed files with 83 additions and 53 deletions

View File

@ -1,6 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from threading import Event from threading import Event
from types import TracebackType
from typing import Optional, Protocol from typing import Optional, Protocol
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
@ -71,9 +70,9 @@ class OnNodeError(Protocol):
self, self,
invocation: BaseInvocation, invocation: BaseInvocation,
queue_item: SessionQueueItem, queue_item: SessionQueueItem,
exc_type: type, error_type: str,
exc_value: BaseException, error_message: str,
exc_traceback: TracebackType, error_traceback: str,
) -> bool: ... ) -> bool: ...
@ -88,8 +87,8 @@ class OnAfterRunSession(Protocol):
class OnNonFatalProcessorError(Protocol): class OnNonFatalProcessorError(Protocol):
def __call__( def __call__(
self, self,
exc_type: type, queue_item: Optional[SessionQueueItem],
exc_value: BaseException, error_type: str,
exc_traceback: TracebackType, error_message: str,
queue_item: Optional[SessionQueueItem] = None, error_traceback: str,
) -> bool: ... ) -> bool: ...

View File

@ -2,7 +2,6 @@ import traceback
from contextlib import suppress from contextlib import suppress
from threading import BoundedSemaphore, Thread from threading import BoundedSemaphore, Thread
from threading import Event as ThreadEvent from threading import Event as ThreadEvent
from types import TracebackType
from typing import Optional from typing import Optional
from fastapi_events.handlers.local import local_handler from fastapi_events.handlers.local import local_handler
@ -30,12 +29,6 @@ from .session_processor_base import InvocationServices, SessionProcessorBase, Se
from .session_processor_common import SessionProcessorStatus from .session_processor_common import SessionProcessorStatus
def get_stacktrace(exc_type: type, exc_value: BaseException, exc_traceback: TracebackType) -> str:
"""Formats a stacktrace as a string"""
return "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))
class DefaultSessionRunner(SessionRunnerBase): class DefaultSessionRunner(SessionRunnerBase):
"""Processes a single session's invocations""" """Processes a single session's invocations"""
@ -71,10 +64,16 @@ class DefaultSessionRunner(SessionRunnerBase):
invocation = queue_item.session.next() invocation = queue_item.session.next()
# Anything other than a `NodeInputError` is handled as a processor error # Anything other than a `NodeInputError` is handled as a processor error
except NodeInputError as e: except NodeInputError as e:
# Must extract the exception traceback here to not lose its stacktrace when we change scope error_type = e.__class__.__name__
traceback = e.__traceback__ error_message = str(e)
assert traceback is not None error_traceback = traceback.format_exc()
self._on_node_error(e.node, queue_item, type(e), e, traceback) self._on_node_error(
invocation=e.node,
queue_item=queue_item,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
break break
if invocation is None or self._cancel_event.is_set(): if invocation is None or self._cancel_event.is_set():
@ -126,10 +125,16 @@ class DefaultSessionRunner(SessionRunnerBase):
# loop go to its next iteration, and the cancel event will be handled correctly. # loop go to its next iteration, and the cancel event will be handled correctly.
pass pass
except Exception as e: except Exception as e:
# Must extract the exception traceback here to not lose its stacktrace when we change scope error_type = e.__class__.__name__
exc_traceback = e.__traceback__ error_message = str(e)
assert exc_traceback is not None error_traceback = traceback.format_exc()
self._on_node_error(invocation, queue_item, type(e), e, exc_traceback) self._on_node_error(
invocation=invocation,
queue_item=queue_item,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
def _on_before_run_session(self, queue_item: SessionQueueItem) -> None: def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
# If profiling is enabled, start the profiler # If profiling is enabled, start the profiler
@ -166,7 +171,7 @@ class DefaultSessionRunner(SessionRunnerBase):
self._services.performance_statistics.reset_stats() self._services.performance_statistics.reset_stats()
for callback in self._on_after_run_session_callbacks: for callback in self._on_after_run_session_callbacks:
callback(queue_item) callback(queue_item=queue_item)
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Run before a node is executed""" """Run before a node is executed"""
@ -181,7 +186,7 @@ class DefaultSessionRunner(SessionRunnerBase):
) )
for callback in self._on_before_run_node_callbacks: for callback in self._on_before_run_node_callbacks:
callback(invocation, queue_item) callback(invocation=invocation, queue_item=queue_item)
def _on_after_run_node( def _on_after_run_node(
self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput
@ -199,23 +204,23 @@ class DefaultSessionRunner(SessionRunnerBase):
) )
for callback in self._on_after_run_node_callbacks: for callback in self._on_after_run_node_callbacks:
callback(invocation, queue_item, output) callback(invocation=invocation, queue_item=queue_item, output=output)
def _on_node_error( def _on_node_error(
self, self,
invocation: BaseInvocation, invocation: BaseInvocation,
queue_item: SessionQueueItem, queue_item: SessionQueueItem,
exc_type: type, error_type: str,
exc_value: BaseException, error_message: str,
exc_traceback: TracebackType, error_traceback: str,
): ):
stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback) # Node errors do not get the full traceback. Only the queue item gets the full traceback.
node_error = f"{error_type}: {error_message}"
queue_item.session.set_node_error(invocation.id, stacktrace) queue_item.session.set_node_error(invocation.id, node_error)
self._services.logger.error( self._services.logger.error(
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}): {exc_type.__name__}" f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}): {error_message}"
) )
self._services.logger.error(stacktrace) self._services.logger.error(error_traceback)
# Send error event # Send error event
self._services.events.emit_invocation_error( self._services.events.emit_invocation_error(
@ -225,14 +230,21 @@ class DefaultSessionRunner(SessionRunnerBase):
graph_execution_state_id=queue_item.session.id, graph_execution_state_id=queue_item.session.id,
node=invocation.model_dump(), node=invocation.model_dump(),
source_node_id=queue_item.session.prepared_source_mapping[invocation.id], source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
error_type=exc_type.__name__, error_type=error_type,
error=stacktrace, error_message=error_message,
error_traceback=error_traceback,
user_id=getattr(queue_item, "user_id", None), user_id=getattr(queue_item, "user_id", None),
project_id=getattr(queue_item, "project_id", None), project_id=getattr(queue_item, "project_id", None),
) )
for callback in self._on_node_error_callbacks: for callback in self._on_node_error_callbacks:
callback(invocation, queue_item, exc_type, exc_value, exc_traceback) callback(
invocation=invocation,
queue_item=queue_item,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
class DefaultSessionProcessor(SessionProcessorBase): class DefaultSessionProcessor(SessionProcessorBase):
@ -374,16 +386,25 @@ class DefaultSessionProcessor(SessionProcessorBase):
self.session_runner.run(queue_item=self._queue_item) self.session_runner.run(queue_item=self._queue_item)
except Exception as e: except Exception as e:
# Must extract the exception traceback here to not lose its stacktrace when we change scope error_type = e.__class__.__name__
exc_traceback = e.__traceback__ error_message = str(e)
assert exc_traceback is not None error_traceback = traceback.format_exc()
self._on_non_fatal_processor_error(self._queue_item, type(e), e, exc_traceback) self._on_non_fatal_processor_error(
# Immediately poll for next queue item queue_item=self._queue_item,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
# Wait for next polling interval or event to try again
poll_now_event.wait(self._polling_interval) poll_now_event.wait(self._polling_interval)
continue continue
except Exception: except Exception as e:
# Fatal error in processor, log and pass - we're done here # Fatal error in processor, log and pass - we're done here
self._invoker.services.logger.error(f"Fatal Error in session processor:\n{traceback.format_exc()}") error_type = e.__class__.__name__
error_message = str(e)
error_traceback = traceback.format_exc()
self._invoker.services.logger.error(f"Fatal Error in session processor {error_type}: {error_message}")
self._invoker.services.logger.error(error_traceback)
pass pass
finally: finally:
stop_event.clear() stop_event.clear()
@ -394,19 +415,29 @@ class DefaultSessionProcessor(SessionProcessorBase):
def _on_non_fatal_processor_error( def _on_non_fatal_processor_error(
self, self,
queue_item: Optional[SessionQueueItem], queue_item: Optional[SessionQueueItem],
exc_type: type, error_type: str,
exc_value: BaseException, error_message: str,
exc_traceback: TracebackType, error_traceback: str,
) -> None: ) -> None:
stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback)
# Non-fatal error in processor # Non-fatal error in processor
self._invoker.services.logger.error(f"Non-fatal error in session processor: {exc_type.__name__}") self._invoker.services.logger.error(f"Non-fatal error in session processor {error_type}: {error_message}")
self._invoker.services.logger.error(stacktrace) self._invoker.services.logger.error(error_traceback)
if queue_item is not None: if queue_item is not None:
# Update the queue item with the completed session # Update the queue item with the completed session
self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session) self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
# And cancel the queue item with an error # Fail the queue item
self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace) self._invoker.services.session_queue.fail_queue_item(
item_id=queue_item.item_id,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
for callback in self._on_non_fatal_processor_error_callbacks: for callback in self._on_non_fatal_processor_error_callbacks:
callback(exc_type, exc_value, exc_traceback, queue_item) callback(
queue_item=queue_item,
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)