Review notes (plain text ahead of the first file header; not part of any hunk):
  * The stored copy of this patch had lost its line breaks and indentation. Both were
    reconstructed from the hunk line counts and Python syntax - confirm the whitespace of
    context and removed lines against the tree before applying.
  * DefaultSessionRunner.run() now forwards `profiler` to the before/after session hooks;
    it was accepted but silently dropped, so the profiler code in the hooks never ran.
  * DefaultSessionRunner._on_node_error() now logs the exception value (was the exception
    class) and emits queue_item.batch_id as queue_batch_id (was queue_item.session_id),
    matching the other emit_* calls in this patch.
  * Follow-up for the author: DefaultSessionProcessor still calls session_runner.run()
    without a profiler, and the print()-based callbacks in dependencies.py look like
    debug scaffolding.

diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py
index 9a6c7416f6..1ff45be866 100644
--- a/invokeai/app/api/dependencies.py
+++ b/invokeai/app/api/dependencies.py
@@ -29,7 +29,7 @@ from ..services.model_images.model_images_default import ModelImageFileStorageDi
 from ..services.model_manager.model_manager_default import ModelManagerService
 from ..services.model_records import ModelRecordServiceSQL
 from ..services.names.names_default import SimpleNameService
-from ..services.session_processor.session_processor_default import DefaultSessionProcessor
+from ..services.session_processor.session_processor_default import DefaultSessionProcessor, DefaultSessionRunner
 from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
 from ..services.urls.urls_default import LocalUrlService
 from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
@@ -103,7 +103,41 @@ class ApiDependencies:
         )
         names = SimpleNameService()
         performance_statistics = InvocationStatsService()
-        session_processor = DefaultSessionProcessor()
+
+        def on_before_run_session(queue_item):
+            print("BEFORE RUN SESSION", queue_item.item_id)
+            return True
+
+        def on_before_run_node(invocation, queue_item):
+            print("BEFORE RUN NODE", invocation.id)
+            return True
+
+        def on_after_run_node(invocation, queue_item, outputs):
+            print("AFTER RUN NODE", invocation.id)
+            return True
+
+        def on_node_error(invocation, queue_item, exc_type, exc_value, exc_traceback):
+            print("NODE ERROR", invocation.id)
+            return True
+
+        def on_after_run_session(queue_item):
+            print("AFTER RUN SESSION", queue_item.item_id)
+            return True
+
+        def on_non_fatal_processor_error(queue_item, exc_type, exc_value, exc_traceback):
+            print("NON FATAL PROCESSOR ERROR", exc_value)
+            return True
+
+        session_processor = DefaultSessionProcessor(
+            DefaultSessionRunner(
+                on_before_run_session,
+                on_before_run_node,
+                on_after_run_node,
+                on_node_error,
+                on_after_run_session,
+            ),
+            on_non_fatal_processor_error,
+        )
         session_queue = SqliteSessionQueue(db=db)
         urls = LocalUrlService()
         workflow_records = SqliteWorkflowRecordsStorage(db=db)
diff --git a/invokeai/app/services/session_processor/session_processor_base.py b/invokeai/app/services/session_processor/session_processor_base.py
index 7a67c3ab2c..7140847518 100644
--- a/invokeai/app/services/session_processor/session_processor_base.py
+++ b/invokeai/app/services/session_processor/session_processor_base.py
@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
 from threading import Event
 
+from invokeai.app.invocations.baseinvocation import BaseInvocation
 from invokeai.app.services.invocation_services import InvocationServices
 from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
 from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
@@ -22,12 +23,7 @@ class SessionRunnerBase(ABC):
         pass
 
     @abstractmethod
-    def complete(self, queue_item: SessionQueueItem) -> None:
-        """Completes the session"""
-        pass
-
-    @abstractmethod
-    def run_node(self, node_id: str, queue_item: SessionQueueItem) -> None:
+    def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None:
         """Runs an already prepared node on the session"""
         pass
 
diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py
index 7f5a107b83..33274dd97b 100644
--- a/invokeai/app/services/session_processor/session_processor_default.py
+++ b/invokeai/app/services/session_processor/session_processor_default.py
@@ -2,12 +2,13 @@ import traceback
 from contextlib import suppress
 from threading import BoundedSemaphore, Thread
 from threading import Event as ThreadEvent
-from typing import Callable, Optional, Union
+from types import TracebackType
+from typing import Callable, Optional, TypeAlias
 
 from fastapi_events.handlers.local import local_handler
 from fastapi_events.typing import Event as FastAPIEvent
 
-from invokeai.app.invocations.baseinvocation import BaseInvocation
+from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
 from invokeai.app.services.events.events_base import EventServiceBase
 from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
 from invokeai.app.services.session_processor.session_processor_common import CanceledException
@@ -19,73 +20,71 @@ from ..invoker import Invoker
 from .session_processor_base import InvocationServices, SessionProcessorBase, SessionRunnerBase
 from .session_processor_common import SessionProcessorStatus
 
+OnBeforeRunNode: TypeAlias = Callable[[BaseInvocation, SessionQueueItem], bool]
+OnAfterRunNode: TypeAlias = Callable[[BaseInvocation, SessionQueueItem, BaseInvocationOutput], bool]
+OnNodeError: TypeAlias = Callable[[BaseInvocation, SessionQueueItem, type, BaseException, TracebackType], bool]
+OnBeforeRunSession: TypeAlias = Callable[[SessionQueueItem], bool]
+OnAfterRunSession: TypeAlias = Callable[[SessionQueueItem], bool]
+OnNonFatalProcessorError: TypeAlias = Callable[[Optional[SessionQueueItem], type, BaseException, TracebackType], bool]
+
+
+def get_stacktrace(exc_type: type, exc_value: BaseException, exc_traceback: TracebackType) -> str:
+    return "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))
+
 
 class DefaultSessionRunner(SessionRunnerBase):
     """Processes a single session's invocations"""
 
     def __init__(
         self,
-        on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
-        on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
+        on_before_run_session: Optional[OnBeforeRunSession] = None,
+        on_before_run_node: Optional[OnBeforeRunNode] = None,
+        on_after_run_node: Optional[OnAfterRunNode] = None,
+        on_node_error: Optional[OnNodeError] = None,
+        on_after_run_session: Optional[OnAfterRunSession] = None,
     ):
+        self.on_before_run_session = on_before_run_session
         self.on_before_run_node = on_before_run_node
         self.on_after_run_node = on_after_run_node
+        self.on_node_error = on_node_error
+        self.on_after_run_session = on_after_run_session
 
     def start(self, services: InvocationServices, cancel_event: ThreadEvent):
         """Start the session runner"""
         self.services = services
         self.cancel_event = cancel_event
 
-    def next_invocation(
-        self, previous_invocation: Optional[BaseInvocation], queue_item: SessionQueueItem, cancel_event: ThreadEvent
-    ) -> Optional[BaseInvocation]:
-        invocation = None
-        if not (queue_item.session.is_complete() or cancel_event.is_set()):
-            try:
-                invocation = queue_item.session.next()
-            except Exception as exc:
-                self.services.logger.error("ERROR: %s" % exc, exc_info=True)
-
-                node_error = str(exc)
-
-                # Save error
-                if previous_invocation is not None:
-                    queue_item.session.set_node_error(previous_invocation.id, node_error)
-
-                # Send error event
-                self.services.events.emit_invocation_error(
-                    queue_batch_id=queue_item.batch_id,
-                    queue_item_id=queue_item.item_id,
-                    queue_id=queue_item.queue_id,
-                    graph_execution_state_id=queue_item.session.id,
-                    node=previous_invocation.model_dump() if previous_invocation else {},
-                    source_node_id=queue_item.session.prepared_source_mapping[previous_invocation.id]
-                    if previous_invocation
-                    else "",
-                    error_type=exc.__class__.__name__,
-                    error=node_error,
-                    user_id=None,
-                    project_id=None,
-                )
-
-            if queue_item.session.is_complete() or cancel_event.is_set():
-                # Set the invocation to None to prepare for the next session
-                invocation = None
-        return invocation
-
-    def run(self, queue_item: SessionQueueItem):
+    def run(self, queue_item: SessionQueueItem, profiler: Optional[Profiler] = None):
         """Run the graph"""
-        if not queue_item.session:
-            raise ValueError("Queue item has no session")
-        invocation = None
         # Loop over invocations until the session is complete or canceled
-        invocation = self.next_invocation(invocation, queue_item, self.cancel_event)
-        while invocation is not None and not self.cancel_event.is_set():
-            self.run_node(invocation.id, queue_item)
-            invocation = self.next_invocation(invocation, queue_item, self.cancel_event)
-        self.complete(queue_item)
 
-    def complete(self, queue_item: SessionQueueItem):
+        self._on_before_run_session(queue_item=queue_item, profiler=profiler)
+        while True:
+            invocation = queue_item.session.next()
+            if invocation is None or self.cancel_event.is_set():
+                break
+            self.run_node(invocation, queue_item)
+            if queue_item.session.is_complete() or self.cancel_event.is_set():
+                break
+        self._on_after_run_session(queue_item=queue_item, profiler=profiler)
+
+    def _on_before_run_session(self, queue_item: SessionQueueItem, profiler: Optional[Profiler] = None) -> None:
+        # If profiling is enabled, start the profiler
+        if profiler is not None:
+            profiler.start(profile_id=queue_item.session_id)
+
+        if self.on_before_run_session:
+            self.on_before_run_session(queue_item)
+
+    def _on_after_run_session(self, queue_item: SessionQueueItem, profiler: Optional[Profiler] = None) -> None:
+        # If we are profiling, stop the profiler and dump the profile & stats
+        if profiler:
+            profile_path = profiler.stop()
+            stats_path = profile_path.with_suffix(".json")
+            self.services.performance_statistics.dump_stats(
+                graph_execution_state_id=queue_item.session.id, output_path=stats_path
+            )
+
         # Send complete event
         self.services.events.emit_graph_execution_complete(
             queue_batch_id=queue_item.batch_id,
@@ -93,12 +92,16 @@ class DefaultSessionRunner(SessionRunnerBase):
             queue_id=queue_item.queue_id,
             graph_execution_state_id=queue_item.session.id,
         )
+
         # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
         # we don't care about that - suppress the error.
         with suppress(GESStatsNotFoundError):
             self.services.performance_statistics.log_stats(queue_item.session.id)
             self.services.performance_statistics.reset_stats()
 
+        if self.on_after_run_session:
+            self.on_after_run_session(queue_item)
+
     def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
         """Run before a node is executed"""
         # Send starting event
@@ -110,28 +113,73 @@ class DefaultSessionRunner(SessionRunnerBase):
             node=invocation.model_dump(),
             source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
         )
+        # And run lifecycle callbacks
         if self.on_before_run_node is not None:
             self.on_before_run_node(invocation, queue_item)
 
-    def _on_after_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
+    def _on_after_run_node(
+        self, invocation: BaseInvocation, queue_item: SessionQueueItem, outputs: BaseInvocationOutput
+    ):
         """Run after a node is executed"""
+        # Send complete event on successful runs
+        self.services.events.emit_invocation_complete(
+            queue_batch_id=queue_item.batch_id,
+            queue_item_id=queue_item.item_id,
+            queue_id=queue_item.queue_id,
+            graph_execution_state_id=queue_item.session.id,
+            node=invocation.model_dump(),
+            source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
+            result=outputs.model_dump(),
+        )
+        # And run lifecycle callbacks
         if self.on_after_run_node is not None:
-            self.on_after_run_node(invocation, queue_item)
+            self.on_after_run_node(invocation, queue_item, outputs)
 
-    def run_node(self, node_id: str, queue_item: SessionQueueItem):
+    def _on_node_error(
+        self,
+        invocation: BaseInvocation,
+        queue_item: SessionQueueItem,
+        exc_type: type,
+        exc_value: BaseException,
+        exc_traceback: TracebackType,
+    ):
+        stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback)
+
+        queue_item.session.set_node_error(invocation.id, stacktrace)
+        self.services.logger.error(
+            f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{exc_value}"
+        )
+        self.services.logger.error(stacktrace)
+
+        # Send error event
+        self.services.events.emit_invocation_error(
+            queue_batch_id=queue_item.batch_id,
+            queue_item_id=queue_item.item_id,
+            queue_id=queue_item.queue_id,
+            graph_execution_state_id=queue_item.session.id,
+            node=invocation.model_dump(),
+            source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
+            error_type=exc_type.__name__,
+            error=stacktrace,
+            user_id=None,
+            project_id=None,
+        )
+
+        if self.on_node_error is not None:
+            self.on_node_error(invocation, queue_item, exc_type, exc_value, exc_traceback)
+
+    def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
         """Run a single node in the graph"""
-        # If this error raises a NodeNotFoundError that's handled by the processor
-        invocation = queue_item.session.execution_graph.get_node(node_id)
         try:
-            self._on_before_run_node(invocation, queue_item)
-            data = InvocationContextData(
-                invocation=invocation,
-                source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
-                queue_item=queue_item,
-            )
-
-            # Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
+            # Any unhandled exception is an invocation error & will fail the graph
             with self.services.performance_statistics.collect_stats(invocation, queue_item.session_id):
+                self._on_before_run_node(invocation, queue_item)
+
+                data = InvocationContextData(
+                    invocation=invocation,
+                    source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
+                    queue_item=queue_item,
+                )
                 context = build_invocation_context(
                     data=data,
                     services=self.services,
@@ -140,21 +188,11 @@ class DefaultSessionRunner(SessionRunnerBase):
 
                 # Invoke the node
                 outputs = invocation.invoke_internal(context=context, services=self.services)
-
                 # Save outputs and history
                 queue_item.session.complete(invocation.id, outputs)
 
-            self._on_after_run_node(invocation, queue_item)
-            # Send complete event on successful runs
-            self.services.events.emit_invocation_complete(
-                queue_batch_id=queue_item.batch_id,
-                queue_item_id=queue_item.item_id,
-                queue_id=queue_item.queue_id,
-                graph_execution_state_id=queue_item.session.id,
-                node=invocation.model_dump(),
-                source_node_id=data.source_invocation_id,
-                result=outputs.model_dump(),
-            )
+            self._on_after_run_node(invocation, queue_item, outputs)
+
         except KeyboardInterrupt:
             # TODO(MM2): Create an event for this
             pass
@@ -171,48 +209,51 @@ class DefaultSessionRunner(SessionRunnerBase):
             # loop go to its next iteration, and the cancel event will be handled correctly.
             pass
         except Exception as e:
-            error = traceback.format_exc()
-
-            # Save error
-            queue_item.session.set_node_error(invocation.id, error)
-            self.services.logger.error(
-                f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{e}"
-            )
-            self.services.logger.error(error)
-
-            # Send error event
-            self.services.events.emit_invocation_error(
-                queue_batch_id=queue_item.session_id,
-                queue_item_id=queue_item.item_id,
-                queue_id=queue_item.queue_id,
-                graph_execution_state_id=queue_item.session.id,
-                node=invocation.model_dump(),
-                source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
-                error_type=e.__class__.__name__,
-                error=error,
-                user_id=None,
-                project_id=None,
-            )
+            exc_type = type(e)
+            exc_value = e
+            exc_traceback = e.__traceback__
+            assert exc_traceback is not None
+            self._on_node_error(invocation, queue_item, exc_type, exc_value, exc_traceback)
 
 
 class DefaultSessionProcessor(SessionProcessorBase):
-    def __init__(self, session_runner: Union[SessionRunnerBase, None] = None) -> None:
+    def __init__(
+        self,
+        session_runner: Optional[SessionRunnerBase] = None,
+        on_non_fatal_processor_error: Optional[OnNonFatalProcessorError] = None,
+    ) -> None:
         super().__init__()
+
         self.session_runner = session_runner if session_runner else DefaultSessionRunner()
+        self.on_non_fatal_processor_error = on_non_fatal_processor_error
+
+    def _on_non_fatal_processor_error(
+        self,
+        queue_item: Optional[SessionQueueItem],
+        exc_type: type,
+        exc_value: BaseException,
+        exc_traceback: TracebackType,
+    ) -> None:
+        stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback)
+        # Non-fatal error in processor
+        self._invoker.services.logger.error(f"Non-fatal error in session processor:\n{stacktrace}")
+        # Cancel the queue item
+        if queue_item is not None:
+            self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
+            self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace)
+
+        if self.on_non_fatal_processor_error:
+            self.on_non_fatal_processor_error(queue_item, exc_type, exc_value, exc_traceback)
 
     def start(
         self,
         invoker: Invoker,
         thread_limit: int = 1,
         polling_interval: int = 1,
-        on_before_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
-        on_after_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
     ) -> None:
         self._invoker: Invoker = invoker
         self._queue_item: Optional[SessionQueueItem] = None
         self._invocation: Optional[BaseInvocation] = None
-        self.on_before_run_session = on_before_run_session
-        self.on_after_run_session = on_after_run_session
 
         self._resume_event = ThreadEvent()
         self._stop_event = ThreadEvent()
@@ -331,40 +372,15 @@ class DefaultSessionProcessor(SessionProcessorBase):
                     self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
                     cancel_event.clear()
 
-                    # If we have a on_before_run_session callback, call it
-                    if self.on_before_run_session is not None:
-                        self.on_before_run_session(self._queue_item)
-
-                    # If profiling is enabled, start the profiler
-                    if self._profiler is not None:
-                        self._profiler.start(profile_id=self._queue_item.session_id)
-
                     # Run the graph
                     self.session_runner.run(queue_item=self._queue_item)
 
-                    # If we are profiling, stop the profiler and dump the profile & stats
-                    if self._profiler:
-                        profile_path = self._profiler.stop()
-                        stats_path = profile_path.with_suffix(".json")
-                        self._invoker.services.performance_statistics.dump_stats(
-                            graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
-                        )
-
-                except Exception:
-                    # Non-fatal error in processor
-                    self._invoker.services.logger.error(
-                        f"Non-fatal error in session processor:\n{traceback.format_exc()}"
-                    )
-                    # Cancel the queue item
-                    if self._queue_item is not None:
-                        self._invoker.services.session_queue.set_queue_item_session(
-                            self._queue_item.item_id, self._queue_item.session
-                        )
-                        self._invoker.services.session_queue.cancel_queue_item(
-                            self._queue_item.item_id, error=traceback.format_exc()
-                        )
-                    # Reset the invocation to None to prepare for the next session
-                    self._invocation = None
+                except Exception as e:
+                    exc_type = type(e)
+                    exc_value = e
+                    exc_traceback = e.__traceback__
+                    assert exc_traceback is not None
+                    self._on_non_fatal_processor_error(self._queue_item, exc_type, exc_value, exc_traceback)
                     # Immediately poll for next queue item
                     poll_now_event.wait(self._polling_interval)
                     continue