mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(processor): update enriched errors & fail_queue_item()
This commit is contained in:
parent
6a34176376
commit
db0ef8d316
@ -1,6 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from threading import Event
|
||||
from types import TracebackType
|
||||
from typing import Optional, Protocol
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||
@ -71,9 +70,9 @@ class OnNodeError(Protocol):
|
||||
self,
|
||||
invocation: BaseInvocation,
|
||||
queue_item: SessionQueueItem,
|
||||
exc_type: type,
|
||||
exc_value: BaseException,
|
||||
exc_traceback: TracebackType,
|
||||
error_type: str,
|
||||
error_message: str,
|
||||
error_traceback: str,
|
||||
) -> bool: ...
|
||||
|
||||
|
||||
@ -88,8 +87,8 @@ class OnAfterRunSession(Protocol):
|
||||
class OnNonFatalProcessorError(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
exc_type: type,
|
||||
exc_value: BaseException,
|
||||
exc_traceback: TracebackType,
|
||||
queue_item: Optional[SessionQueueItem] = None,
|
||||
queue_item: Optional[SessionQueueItem],
|
||||
error_type: str,
|
||||
error_message: str,
|
||||
error_traceback: str,
|
||||
) -> bool: ...
|
||||
|
@ -2,7 +2,6 @@ import traceback
|
||||
from contextlib import suppress
|
||||
from threading import BoundedSemaphore, Thread
|
||||
from threading import Event as ThreadEvent
|
||||
from types import TracebackType
|
||||
from typing import Optional
|
||||
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
@ -30,12 +29,6 @@ from .session_processor_base import InvocationServices, SessionProcessorBase, Se
|
||||
from .session_processor_common import SessionProcessorStatus
|
||||
|
||||
|
||||
def get_stacktrace(exc_type: type, exc_value: BaseException, exc_traceback: TracebackType) -> str:
|
||||
"""Formats a stacktrace as a string"""
|
||||
|
||||
return "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))
|
||||
|
||||
|
||||
class DefaultSessionRunner(SessionRunnerBase):
|
||||
"""Processes a single session's invocations"""
|
||||
|
||||
@ -71,10 +64,16 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
invocation = queue_item.session.next()
|
||||
# Anything other than a `NodeInputError` is handled as a processor error
|
||||
except NodeInputError as e:
|
||||
# Must extract the exception traceback here to not lose its stacktrace when we change scope
|
||||
traceback = e.__traceback__
|
||||
assert traceback is not None
|
||||
self._on_node_error(e.node, queue_item, type(e), e, traceback)
|
||||
error_type = e.__class__.__name__
|
||||
error_message = str(e)
|
||||
error_traceback = traceback.format_exc()
|
||||
self._on_node_error(
|
||||
invocation=e.node,
|
||||
queue_item=queue_item,
|
||||
error_type=error_type,
|
||||
error_message=error_message,
|
||||
error_traceback=error_traceback,
|
||||
)
|
||||
break
|
||||
|
||||
if invocation is None or self._cancel_event.is_set():
|
||||
@ -126,10 +125,16 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
# loop go to its next iteration, and the cancel event will be handled correctly.
|
||||
pass
|
||||
except Exception as e:
|
||||
# Must extract the exception traceback here to not lose its stacktrace when we change scope
|
||||
exc_traceback = e.__traceback__
|
||||
assert exc_traceback is not None
|
||||
self._on_node_error(invocation, queue_item, type(e), e, exc_traceback)
|
||||
error_type = e.__class__.__name__
|
||||
error_message = str(e)
|
||||
error_traceback = traceback.format_exc()
|
||||
self._on_node_error(
|
||||
invocation=invocation,
|
||||
queue_item=queue_item,
|
||||
error_type=error_type,
|
||||
error_message=error_message,
|
||||
error_traceback=error_traceback,
|
||||
)
|
||||
|
||||
def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
|
||||
# If profiling is enabled, start the profiler
|
||||
@ -166,7 +171,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
self._services.performance_statistics.reset_stats()
|
||||
|
||||
for callback in self._on_after_run_session_callbacks:
|
||||
callback(queue_item)
|
||||
callback(queue_item=queue_item)
|
||||
|
||||
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||
"""Run before a node is executed"""
|
||||
@ -181,7 +186,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
)
|
||||
|
||||
for callback in self._on_before_run_node_callbacks:
|
||||
callback(invocation, queue_item)
|
||||
callback(invocation=invocation, queue_item=queue_item)
|
||||
|
||||
def _on_after_run_node(
|
||||
self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput
|
||||
@ -199,23 +204,23 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
)
|
||||
|
||||
for callback in self._on_after_run_node_callbacks:
|
||||
callback(invocation, queue_item, output)
|
||||
callback(invocation=invocation, queue_item=queue_item, output=output)
|
||||
|
||||
def _on_node_error(
|
||||
self,
|
||||
invocation: BaseInvocation,
|
||||
queue_item: SessionQueueItem,
|
||||
exc_type: type,
|
||||
exc_value: BaseException,
|
||||
exc_traceback: TracebackType,
|
||||
error_type: str,
|
||||
error_message: str,
|
||||
error_traceback: str,
|
||||
):
|
||||
stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback)
|
||||
|
||||
queue_item.session.set_node_error(invocation.id, stacktrace)
|
||||
# Node errors do not get the full traceback. Only the queue item gets the full traceback.
|
||||
node_error = f"{error_type}: {error_message}"
|
||||
queue_item.session.set_node_error(invocation.id, node_error)
|
||||
self._services.logger.error(
|
||||
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}): {exc_type.__name__}"
|
||||
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}): {error_message}"
|
||||
)
|
||||
self._services.logger.error(stacktrace)
|
||||
self._services.logger.error(error_traceback)
|
||||
|
||||
# Send error event
|
||||
self._services.events.emit_invocation_error(
|
||||
@ -225,14 +230,21 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
graph_execution_state_id=queue_item.session.id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
error_type=exc_type.__name__,
|
||||
error=stacktrace,
|
||||
error_type=error_type,
|
||||
error_message=error_message,
|
||||
error_traceback=error_traceback,
|
||||
user_id=getattr(queue_item, "user_id", None),
|
||||
project_id=getattr(queue_item, "project_id", None),
|
||||
)
|
||||
|
||||
for callback in self._on_node_error_callbacks:
|
||||
callback(invocation, queue_item, exc_type, exc_value, exc_traceback)
|
||||
callback(
|
||||
invocation=invocation,
|
||||
queue_item=queue_item,
|
||||
error_type=error_type,
|
||||
error_message=error_message,
|
||||
error_traceback=error_traceback,
|
||||
)
|
||||
|
||||
|
||||
class DefaultSessionProcessor(SessionProcessorBase):
|
||||
@ -374,16 +386,25 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
self.session_runner.run(queue_item=self._queue_item)
|
||||
|
||||
except Exception as e:
|
||||
# Must extract the exception traceback here to not lose its stacktrace when we change scope
|
||||
exc_traceback = e.__traceback__
|
||||
assert exc_traceback is not None
|
||||
self._on_non_fatal_processor_error(self._queue_item, type(e), e, exc_traceback)
|
||||
# Immediately poll for next queue item
|
||||
error_type = e.__class__.__name__
|
||||
error_message = str(e)
|
||||
error_traceback = traceback.format_exc()
|
||||
self._on_non_fatal_processor_error(
|
||||
queue_item=self._queue_item,
|
||||
error_type=error_type,
|
||||
error_message=error_message,
|
||||
error_traceback=error_traceback,
|
||||
)
|
||||
# Wait for next polling interval or event to try again
|
||||
poll_now_event.wait(self._polling_interval)
|
||||
continue
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
# Fatal error in processor, log and pass - we're done here
|
||||
self._invoker.services.logger.error(f"Fatal Error in session processor:\n{traceback.format_exc()}")
|
||||
error_type = e.__class__.__name__
|
||||
error_message = str(e)
|
||||
error_traceback = traceback.format_exc()
|
||||
self._invoker.services.logger.error(f"Fatal Error in session processor {error_type}: {error_message}")
|
||||
self._invoker.services.logger.error(error_traceback)
|
||||
pass
|
||||
finally:
|
||||
stop_event.clear()
|
||||
@ -394,19 +415,29 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
def _on_non_fatal_processor_error(
|
||||
self,
|
||||
queue_item: Optional[SessionQueueItem],
|
||||
exc_type: type,
|
||||
exc_value: BaseException,
|
||||
exc_traceback: TracebackType,
|
||||
error_type: str,
|
||||
error_message: str,
|
||||
error_traceback: str,
|
||||
) -> None:
|
||||
stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback)
|
||||
# Non-fatal error in processor
|
||||
self._invoker.services.logger.error(f"Non-fatal error in session processor: {exc_type.__name__}")
|
||||
self._invoker.services.logger.error(stacktrace)
|
||||
self._invoker.services.logger.error(f"Non-fatal error in session processor {error_type}: {error_message}")
|
||||
self._invoker.services.logger.error(error_traceback)
|
||||
|
||||
if queue_item is not None:
|
||||
# Update the queue item with the completed session
|
||||
self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
|
||||
# And cancel the queue item with an error
|
||||
self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace)
|
||||
# Fail the queue item
|
||||
self._invoker.services.session_queue.fail_queue_item(
|
||||
item_id=queue_item.item_id,
|
||||
error_type=error_type,
|
||||
error_message=error_message,
|
||||
error_traceback=error_traceback,
|
||||
)
|
||||
|
||||
for callback in self._on_non_fatal_processor_error_callbacks:
|
||||
callback(exc_type, exc_value, exc_traceback, queue_item)
|
||||
callback(
|
||||
queue_item=queue_item,
|
||||
error_type=error_type,
|
||||
error_message=error_message,
|
||||
error_traceback=error_traceback,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user