mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(processor): update enriched errors & fail_queue_item()
This commit is contained in:
parent
a8492bd7e4
commit
2dd3a85ade
@ -1,6 +1,5 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from threading import Event
|
from threading import Event
|
||||||
from types import TracebackType
|
|
||||||
from typing import Optional, Protocol
|
from typing import Optional, Protocol
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||||
@ -71,9 +70,9 @@ class OnNodeError(Protocol):
|
|||||||
self,
|
self,
|
||||||
invocation: BaseInvocation,
|
invocation: BaseInvocation,
|
||||||
queue_item: SessionQueueItem,
|
queue_item: SessionQueueItem,
|
||||||
exc_type: type,
|
error_type: str,
|
||||||
exc_value: BaseException,
|
error_message: str,
|
||||||
exc_traceback: TracebackType,
|
error_traceback: str,
|
||||||
) -> bool: ...
|
) -> bool: ...
|
||||||
|
|
||||||
|
|
||||||
@ -88,8 +87,8 @@ class OnAfterRunSession(Protocol):
|
|||||||
class OnNonFatalProcessorError(Protocol):
|
class OnNonFatalProcessorError(Protocol):
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
exc_type: type,
|
queue_item: Optional[SessionQueueItem],
|
||||||
exc_value: BaseException,
|
error_type: str,
|
||||||
exc_traceback: TracebackType,
|
error_message: str,
|
||||||
queue_item: Optional[SessionQueueItem] = None,
|
error_traceback: str,
|
||||||
) -> bool: ...
|
) -> bool: ...
|
||||||
|
@ -2,7 +2,6 @@ import traceback
|
|||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from threading import BoundedSemaphore, Thread
|
from threading import BoundedSemaphore, Thread
|
||||||
from threading import Event as ThreadEvent
|
from threading import Event as ThreadEvent
|
||||||
from types import TracebackType
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi_events.handlers.local import local_handler
|
from fastapi_events.handlers.local import local_handler
|
||||||
@ -30,12 +29,6 @@ from .session_processor_base import InvocationServices, SessionProcessorBase, Se
|
|||||||
from .session_processor_common import SessionProcessorStatus
|
from .session_processor_common import SessionProcessorStatus
|
||||||
|
|
||||||
|
|
||||||
def get_stacktrace(exc_type: type, exc_value: BaseException, exc_traceback: TracebackType) -> str:
|
|
||||||
"""Formats a stacktrace as a string"""
|
|
||||||
|
|
||||||
return "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))
|
|
||||||
|
|
||||||
|
|
||||||
class DefaultSessionRunner(SessionRunnerBase):
|
class DefaultSessionRunner(SessionRunnerBase):
|
||||||
"""Processes a single session's invocations"""
|
"""Processes a single session's invocations"""
|
||||||
|
|
||||||
@ -71,10 +64,16 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
invocation = queue_item.session.next()
|
invocation = queue_item.session.next()
|
||||||
# Anything other than a `NodeInputError` is handled as a processor error
|
# Anything other than a `NodeInputError` is handled as a processor error
|
||||||
except NodeInputError as e:
|
except NodeInputError as e:
|
||||||
# Must extract the exception traceback here to not lose its stacktrace when we change scope
|
error_type = e.__class__.__name__
|
||||||
traceback = e.__traceback__
|
error_message = str(e)
|
||||||
assert traceback is not None
|
error_traceback = traceback.format_exc()
|
||||||
self._on_node_error(e.node, queue_item, type(e), e, traceback)
|
self._on_node_error(
|
||||||
|
invocation=e.node,
|
||||||
|
queue_item=queue_item,
|
||||||
|
error_type=error_type,
|
||||||
|
error_message=error_message,
|
||||||
|
error_traceback=error_traceback,
|
||||||
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
if invocation is None or self._cancel_event.is_set():
|
if invocation is None or self._cancel_event.is_set():
|
||||||
@ -126,10 +125,16 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
# loop go to its next iteration, and the cancel event will be handled correctly.
|
# loop go to its next iteration, and the cancel event will be handled correctly.
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Must extract the exception traceback here to not lose its stacktrace when we change scope
|
error_type = e.__class__.__name__
|
||||||
exc_traceback = e.__traceback__
|
error_message = str(e)
|
||||||
assert exc_traceback is not None
|
error_traceback = traceback.format_exc()
|
||||||
self._on_node_error(invocation, queue_item, type(e), e, exc_traceback)
|
self._on_node_error(
|
||||||
|
invocation=invocation,
|
||||||
|
queue_item=queue_item,
|
||||||
|
error_type=error_type,
|
||||||
|
error_message=error_message,
|
||||||
|
error_traceback=error_traceback,
|
||||||
|
)
|
||||||
|
|
||||||
def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
|
def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
|
||||||
# If profiling is enabled, start the profiler
|
# If profiling is enabled, start the profiler
|
||||||
@ -166,7 +171,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
self._services.performance_statistics.reset_stats()
|
self._services.performance_statistics.reset_stats()
|
||||||
|
|
||||||
for callback in self._on_after_run_session_callbacks:
|
for callback in self._on_after_run_session_callbacks:
|
||||||
callback(queue_item)
|
callback(queue_item=queue_item)
|
||||||
|
|
||||||
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||||
"""Run before a node is executed"""
|
"""Run before a node is executed"""
|
||||||
@ -181,7 +186,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
for callback in self._on_before_run_node_callbacks:
|
for callback in self._on_before_run_node_callbacks:
|
||||||
callback(invocation, queue_item)
|
callback(invocation=invocation, queue_item=queue_item)
|
||||||
|
|
||||||
def _on_after_run_node(
|
def _on_after_run_node(
|
||||||
self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput
|
self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput
|
||||||
@ -199,23 +204,23 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
for callback in self._on_after_run_node_callbacks:
|
for callback in self._on_after_run_node_callbacks:
|
||||||
callback(invocation, queue_item, output)
|
callback(invocation=invocation, queue_item=queue_item, output=output)
|
||||||
|
|
||||||
def _on_node_error(
|
def _on_node_error(
|
||||||
self,
|
self,
|
||||||
invocation: BaseInvocation,
|
invocation: BaseInvocation,
|
||||||
queue_item: SessionQueueItem,
|
queue_item: SessionQueueItem,
|
||||||
exc_type: type,
|
error_type: str,
|
||||||
exc_value: BaseException,
|
error_message: str,
|
||||||
exc_traceback: TracebackType,
|
error_traceback: str,
|
||||||
):
|
):
|
||||||
stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback)
|
# Node errors do not get the full traceback. Only the queue item gets the full traceback.
|
||||||
|
node_error = f"{error_type}: {error_message}"
|
||||||
queue_item.session.set_node_error(invocation.id, stacktrace)
|
queue_item.session.set_node_error(invocation.id, node_error)
|
||||||
self._services.logger.error(
|
self._services.logger.error(
|
||||||
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}): {exc_type.__name__}"
|
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}): {error_message}"
|
||||||
)
|
)
|
||||||
self._services.logger.error(stacktrace)
|
self._services.logger.error(error_traceback)
|
||||||
|
|
||||||
# Send error event
|
# Send error event
|
||||||
self._services.events.emit_invocation_error(
|
self._services.events.emit_invocation_error(
|
||||||
@ -225,14 +230,21 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
graph_execution_state_id=queue_item.session.id,
|
graph_execution_state_id=queue_item.session.id,
|
||||||
node=invocation.model_dump(),
|
node=invocation.model_dump(),
|
||||||
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
|
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||||
error_type=exc_type.__name__,
|
error_type=error_type,
|
||||||
error=stacktrace,
|
error_message=error_message,
|
||||||
|
error_traceback=error_traceback,
|
||||||
user_id=getattr(queue_item, "user_id", None),
|
user_id=getattr(queue_item, "user_id", None),
|
||||||
project_id=getattr(queue_item, "project_id", None),
|
project_id=getattr(queue_item, "project_id", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
for callback in self._on_node_error_callbacks:
|
for callback in self._on_node_error_callbacks:
|
||||||
callback(invocation, queue_item, exc_type, exc_value, exc_traceback)
|
callback(
|
||||||
|
invocation=invocation,
|
||||||
|
queue_item=queue_item,
|
||||||
|
error_type=error_type,
|
||||||
|
error_message=error_message,
|
||||||
|
error_traceback=error_traceback,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DefaultSessionProcessor(SessionProcessorBase):
|
class DefaultSessionProcessor(SessionProcessorBase):
|
||||||
@ -374,16 +386,25 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
self.session_runner.run(queue_item=self._queue_item)
|
self.session_runner.run(queue_item=self._queue_item)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Must extract the exception traceback here to not lose its stacktrace when we change scope
|
error_type = e.__class__.__name__
|
||||||
exc_traceback = e.__traceback__
|
error_message = str(e)
|
||||||
assert exc_traceback is not None
|
error_traceback = traceback.format_exc()
|
||||||
self._on_non_fatal_processor_error(self._queue_item, type(e), e, exc_traceback)
|
self._on_non_fatal_processor_error(
|
||||||
# Immediately poll for next queue item
|
queue_item=self._queue_item,
|
||||||
|
error_type=error_type,
|
||||||
|
error_message=error_message,
|
||||||
|
error_traceback=error_traceback,
|
||||||
|
)
|
||||||
|
# Wait for next polling interval or event to try again
|
||||||
poll_now_event.wait(self._polling_interval)
|
poll_now_event.wait(self._polling_interval)
|
||||||
continue
|
continue
|
||||||
except Exception:
|
except Exception as e:
|
||||||
# Fatal error in processor, log and pass - we're done here
|
# Fatal error in processor, log and pass - we're done here
|
||||||
self._invoker.services.logger.error(f"Fatal Error in session processor:\n{traceback.format_exc()}")
|
error_type = e.__class__.__name__
|
||||||
|
error_message = str(e)
|
||||||
|
error_traceback = traceback.format_exc()
|
||||||
|
self._invoker.services.logger.error(f"Fatal Error in session processor {error_type}: {error_message}")
|
||||||
|
self._invoker.services.logger.error(error_traceback)
|
||||||
pass
|
pass
|
||||||
finally:
|
finally:
|
||||||
stop_event.clear()
|
stop_event.clear()
|
||||||
@ -394,19 +415,29 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
def _on_non_fatal_processor_error(
|
def _on_non_fatal_processor_error(
|
||||||
self,
|
self,
|
||||||
queue_item: Optional[SessionQueueItem],
|
queue_item: Optional[SessionQueueItem],
|
||||||
exc_type: type,
|
error_type: str,
|
||||||
exc_value: BaseException,
|
error_message: str,
|
||||||
exc_traceback: TracebackType,
|
error_traceback: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback)
|
|
||||||
# Non-fatal error in processor
|
# Non-fatal error in processor
|
||||||
self._invoker.services.logger.error(f"Non-fatal error in session processor: {exc_type.__name__}")
|
self._invoker.services.logger.error(f"Non-fatal error in session processor {error_type}: {error_message}")
|
||||||
self._invoker.services.logger.error(stacktrace)
|
self._invoker.services.logger.error(error_traceback)
|
||||||
|
|
||||||
if queue_item is not None:
|
if queue_item is not None:
|
||||||
# Update the queue item with the completed session
|
# Update the queue item with the completed session
|
||||||
self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
|
self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
|
||||||
# And cancel the queue item with an error
|
# Fail the queue item
|
||||||
self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace)
|
self._invoker.services.session_queue.fail_queue_item(
|
||||||
|
item_id=queue_item.item_id,
|
||||||
|
error_type=error_type,
|
||||||
|
error_message=error_message,
|
||||||
|
error_traceback=error_traceback,
|
||||||
|
)
|
||||||
|
|
||||||
for callback in self._on_non_fatal_processor_error_callbacks:
|
for callback in self._on_non_fatal_processor_error_callbacks:
|
||||||
callback(exc_type, exc_value, exc_traceback, queue_item)
|
callback(
|
||||||
|
queue_item=queue_item,
|
||||||
|
error_type=error_type,
|
||||||
|
error_message=error_message,
|
||||||
|
error_traceback=error_traceback,
|
||||||
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user