feat(app): iterate on processor split

- Add `OnNodeError` and `OnNonFatalProcessorError` callbacks
- Move all session/node callbacks to `SessionRunner` - this ensures we dump perf stats before resetting them and generally makes sense to me
- Remove `complete` event from `SessionRunner`, it's essentially the same as `OnAfterRunSession`
- Remove extraneous `next_invocation` block, which would treat a processor error as a node error
- Simplify loops
- Add some callbacks for testing, to be removed before merge
This commit is contained in:
psychedelicious 2024-05-22 18:33:12 +10:00
parent 82b4298b03
commit be41c84305
3 changed files with 188 additions and 142 deletions

View File

@ -29,7 +29,7 @@ from ..services.model_images.model_images_default import ModelImageFileStorageDi
from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
from ..services.session_processor.session_processor_default import DefaultSessionProcessor, DefaultSessionRunner
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
from ..services.urls.urls_default import LocalUrlService
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
@ -103,7 +103,41 @@ class ApiDependencies:
)
names = SimpleNameService()
performance_statistics = InvocationStatsService()
session_processor = DefaultSessionProcessor()
def on_before_run_session(queue_item):
print("BEFORE RUN SESSION", queue_item.item_id)
return True
def on_before_run_node(invocation, queue_item):
print("BEFORE RUN NODE", invocation.id)
return True
def on_after_run_node(invocation, queue_item, outputs):
print("AFTER RUN NODE", invocation.id)
return True
def on_node_error(invocation, queue_item, exc_type, exc_value, exc_traceback):
print("NODE ERROR", invocation.id)
return True
def on_after_run_session(queue_item):
print("AFTER RUN SESSION", queue_item.item_id)
return True
def on_non_fatal_processor_error(queue_item, exc_type, exc_value, exc_traceback):
print("NON FATAL PROCESSOR ERROR", exc_value)
return True
session_processor = DefaultSessionProcessor(
DefaultSessionRunner(
on_before_run_session,
on_before_run_node,
on_after_run_node,
on_node_error,
on_after_run_session,
),
on_non_fatal_processor_error,
)
session_queue = SqliteSessionQueue(db=db)
urls = LocalUrlService()
workflow_records = SqliteWorkflowRecordsStorage(db=db)

View File

@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from threading import Event
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
@ -22,12 +23,7 @@ class SessionRunnerBase(ABC):
pass
@abstractmethod
def complete(self, queue_item: SessionQueueItem) -> None:
"""Completes the session"""
pass
@abstractmethod
def run_node(self, node_id: str, queue_item: SessionQueueItem) -> None:
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None:
"""Runs an already prepared node on the session"""
pass

View File

@ -2,12 +2,13 @@ import traceback
from contextlib import suppress
from threading import BoundedSemaphore, Thread
from threading import Event as ThreadEvent
from typing import Callable, Optional, Union
from types import TracebackType
from typing import Callable, Optional, TypeAlias
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event as FastAPIEvent
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
from invokeai.app.services.session_processor.session_processor_common import CanceledException
@ -19,73 +20,71 @@ from ..invoker import Invoker
from .session_processor_base import InvocationServices, SessionProcessorBase, SessionRunnerBase
from .session_processor_common import SessionProcessorStatus
OnBeforeRunNode: TypeAlias = Callable[[BaseInvocation, SessionQueueItem], bool]
OnAfterRunNode: TypeAlias = Callable[[BaseInvocation, SessionQueueItem, BaseInvocationOutput], bool]
OnNodeError: TypeAlias = Callable[[BaseInvocation, SessionQueueItem, type, BaseException, TracebackType], bool]
OnBeforeRunSession: TypeAlias = Callable[[SessionQueueItem], bool]
OnAfterRunSession: TypeAlias = Callable[[SessionQueueItem], bool]
OnNonFatalProcessorError: TypeAlias = Callable[[Optional[SessionQueueItem], type, BaseException, TracebackType], bool]
def get_stacktrace(exc_type: type, exc_value: BaseException, exc_traceback: TracebackType) -> str:
return "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))
class DefaultSessionRunner(SessionRunnerBase):
"""Processes a single session's invocations"""
def __init__(
self,
on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
on_before_run_session: Optional[OnBeforeRunSession] = None,
on_before_run_node: Optional[OnBeforeRunNode] = None,
on_after_run_node: Optional[OnAfterRunNode] = None,
on_node_error: Optional[OnNodeError] = None,
on_after_run_session: Optional[OnAfterRunSession] = None,
):
self.on_before_run_session = on_before_run_session
self.on_before_run_node = on_before_run_node
self.on_after_run_node = on_after_run_node
self.on_node_error = on_node_error
self.on_after_run_session = on_after_run_session
def start(self, services: InvocationServices, cancel_event: ThreadEvent):
"""Start the session runner"""
self.services = services
self.cancel_event = cancel_event
def next_invocation(
self, previous_invocation: Optional[BaseInvocation], queue_item: SessionQueueItem, cancel_event: ThreadEvent
) -> Optional[BaseInvocation]:
invocation = None
if not (queue_item.session.is_complete() or cancel_event.is_set()):
try:
def run(self, queue_item: SessionQueueItem, profiler: Optional[Profiler] = None):
"""Run the graph"""
# Loop over invocations until the session is complete or canceled
self._on_before_run_session(queue_item=queue_item)
while True:
invocation = queue_item.session.next()
except Exception as exc:
self.services.logger.error("ERROR: %s" % exc, exc_info=True)
if invocation is None or self.cancel_event.is_set():
break
self.run_node(invocation, queue_item)
if queue_item.session.is_complete() or self.cancel_event.is_set():
break
self._on_after_run_session(queue_item=queue_item)
node_error = str(exc)
def _on_before_run_session(self, queue_item: SessionQueueItem, profiler: Optional[Profiler] = None) -> None:
# If profiling is enabled, start the profiler
if profiler is not None:
profiler.start(profile_id=queue_item.session_id)
# Save error
if previous_invocation is not None:
queue_item.session.set_node_error(previous_invocation.id, node_error)
if self.on_before_run_session:
self.on_before_run_session(queue_item)
# Send error event
self.services.events.emit_invocation_error(
queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session.id,
node=previous_invocation.model_dump() if previous_invocation else {},
source_node_id=queue_item.session.prepared_source_mapping[previous_invocation.id]
if previous_invocation
else "",
error_type=exc.__class__.__name__,
error=node_error,
user_id=None,
project_id=None,
def _on_after_run_session(self, queue_item: SessionQueueItem, profiler: Optional[Profiler] = None) -> None:
# If we are profiling, stop the profiler and dump the profile & stats
if profiler:
profile_path = profiler.stop()
stats_path = profile_path.with_suffix(".json")
self.services.performance_statistics.dump_stats(
graph_execution_state_id=queue_item.session.id, output_path=stats_path
)
if queue_item.session.is_complete() or cancel_event.is_set():
# Set the invocation to None to prepare for the next session
invocation = None
return invocation
def run(self, queue_item: SessionQueueItem):
"""Run the graph"""
if not queue_item.session:
raise ValueError("Queue item has no session")
invocation = None
# Loop over invocations until the session is complete or canceled
invocation = self.next_invocation(invocation, queue_item, self.cancel_event)
while invocation is not None and not self.cancel_event.is_set():
self.run_node(invocation.id, queue_item)
invocation = self.next_invocation(invocation, queue_item, self.cancel_event)
self.complete(queue_item)
def complete(self, queue_item: SessionQueueItem):
# Send complete event
self.services.events.emit_graph_execution_complete(
queue_batch_id=queue_item.batch_id,
@ -93,12 +92,16 @@ class DefaultSessionRunner(SessionRunnerBase):
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session.id,
)
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self.services.performance_statistics.log_stats(queue_item.session.id)
self.services.performance_statistics.reset_stats()
if self.on_after_run_session:
self.on_after_run_session(queue_item)
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Run before a node is executed"""
# Send starting event
@ -110,28 +113,73 @@ class DefaultSessionRunner(SessionRunnerBase):
node=invocation.model_dump(),
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
)
# And run lifecycle callbacks
if self.on_before_run_node is not None:
self.on_before_run_node(invocation, queue_item)
def _on_after_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
def _on_after_run_node(
self, invocation: BaseInvocation, queue_item: SessionQueueItem, outputs: BaseInvocationOutput
):
"""Run after a node is executed"""
# Send complete event on successful runs
self.services.events.emit_invocation_complete(
queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session.id,
node=invocation.model_dump(),
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
result=outputs.model_dump(),
)
# And run lifecycle callbacks
if self.on_after_run_node is not None:
self.on_after_run_node(invocation, queue_item)
self.on_after_run_node(invocation, queue_item, outputs)
def run_node(self, node_id: str, queue_item: SessionQueueItem):
def _on_node_error(
self,
invocation: BaseInvocation,
queue_item: SessionQueueItem,
exc_type: type,
exc_value: BaseException,
exc_traceback: TracebackType,
):
stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback)
queue_item.session.set_node_error(invocation.id, stacktrace)
self.services.logger.error(
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{exc_type}"
)
self.services.logger.error(stacktrace)
# Send error event
self.services.events.emit_invocation_error(
queue_batch_id=queue_item.session_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session.id,
node=invocation.model_dump(),
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
error_type=exc_type.__name__,
error=stacktrace,
user_id=None,
project_id=None,
)
if self.on_node_error is not None:
self.on_node_error(invocation, queue_item, exc_type, exc_value, exc_traceback)
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Run a single node in the graph"""
# If this error raises a NodeNotFoundError that's handled by the processor
invocation = queue_item.session.execution_graph.get_node(node_id)
try:
# Any unhandled exception is an invocation error & will fail the graph
with self.services.performance_statistics.collect_stats(invocation, queue_item.session_id):
self._on_before_run_node(invocation, queue_item)
data = InvocationContextData(
invocation=invocation,
source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
queue_item=queue_item,
)
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
with self.services.performance_statistics.collect_stats(invocation, queue_item.session_id):
context = build_invocation_context(
data=data,
services=self.services,
@ -140,21 +188,11 @@ class DefaultSessionRunner(SessionRunnerBase):
# Invoke the node
outputs = invocation.invoke_internal(context=context, services=self.services)
# Save outputs and history
queue_item.session.complete(invocation.id, outputs)
self._on_after_run_node(invocation, queue_item)
# Send complete event on successful runs
self.services.events.emit_invocation_complete(
queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session.id,
node=invocation.model_dump(),
source_node_id=data.source_invocation_id,
result=outputs.model_dump(),
)
self._on_after_run_node(invocation, queue_item, outputs)
except KeyboardInterrupt:
# TODO(MM2): Create an event for this
pass
@ -171,48 +209,51 @@ class DefaultSessionRunner(SessionRunnerBase):
# loop go to its next iteration, and the cancel event will be handled correctly.
pass
except Exception as e:
error = traceback.format_exc()
# Save error
queue_item.session.set_node_error(invocation.id, error)
self.services.logger.error(
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{e}"
)
self.services.logger.error(error)
# Send error event
self.services.events.emit_invocation_error(
queue_batch_id=queue_item.session_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session.id,
node=invocation.model_dump(),
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
error_type=e.__class__.__name__,
error=error,
user_id=None,
project_id=None,
)
exc_type = type(e)
exc_value = e
exc_traceback = e.__traceback__
assert exc_traceback is not None
self._on_node_error(invocation, queue_item, exc_type, exc_value, exc_traceback)
class DefaultSessionProcessor(SessionProcessorBase):
def __init__(self, session_runner: Union[SessionRunnerBase, None] = None) -> None:
def __init__(
self,
session_runner: Optional[SessionRunnerBase] = None,
on_non_fatal_processor_error: Optional[OnNonFatalProcessorError] = None,
) -> None:
super().__init__()
self.session_runner = session_runner if session_runner else DefaultSessionRunner()
self.on_non_fatal_processor_error = on_non_fatal_processor_error
def _on_non_fatal_processor_error(
self,
queue_item: Optional[SessionQueueItem],
exc_type: type,
exc_value: BaseException,
exc_traceback: TracebackType,
) -> None:
stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback)
# Non-fatal error in processor
self._invoker.services.logger.error(f"Non-fatal error in session processor:\n{stacktrace}")
# Cancel the queue item
if queue_item is not None:
self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace)
if self.on_non_fatal_processor_error:
self.on_non_fatal_processor_error(queue_item, exc_type, exc_value, exc_traceback)
def start(
self,
invoker: Invoker,
thread_limit: int = 1,
polling_interval: int = 1,
on_before_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
on_after_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
) -> None:
self._invoker: Invoker = invoker
self._queue_item: Optional[SessionQueueItem] = None
self._invocation: Optional[BaseInvocation] = None
self.on_before_run_session = on_before_run_session
self.on_after_run_session = on_after_run_session
self._resume_event = ThreadEvent()
self._stop_event = ThreadEvent()
@ -331,40 +372,15 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
cancel_event.clear()
# If we have a on_before_run_session callback, call it
if self.on_before_run_session is not None:
self.on_before_run_session(self._queue_item)
# If profiling is enabled, start the profiler
if self._profiler is not None:
self._profiler.start(profile_id=self._queue_item.session_id)
# Run the graph
self.session_runner.run(queue_item=self._queue_item)
# If we are profiling, stop the profiler and dump the profile & stats
if self._profiler:
profile_path = self._profiler.stop()
stats_path = profile_path.with_suffix(".json")
self._invoker.services.performance_statistics.dump_stats(
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
)
except Exception:
# Non-fatal error in processor
self._invoker.services.logger.error(
f"Non-fatal error in session processor:\n{traceback.format_exc()}"
)
# Cancel the queue item
if self._queue_item is not None:
self._invoker.services.session_queue.set_queue_item_session(
self._queue_item.item_id, self._queue_item.session
)
self._invoker.services.session_queue.cancel_queue_item(
self._queue_item.item_id, error=traceback.format_exc()
)
# Reset the invocation to None to prepare for the next session
self._invocation = None
except Exception as e:
exc_type = type(e)
exc_value = e
exc_traceback = e.__traceback__
assert exc_traceback is not None
self._on_non_fatal_processor_error(self._queue_item, exc_type, exc_value, exc_traceback)
# Immediately poll for next queue item
poll_now_event.wait(self._polling_interval)
continue