mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(app): iterate on processor split
- Add `OnNodeError` and `OnNonFatalProcessorError` callbacks - Move all session/node callbacks to `SessionRunner` - this ensures we dump perf stats before resetting them and generally makes sense to me - Remove `complete` event from `SessionRunner`, it's essentially the same as `OnAfterRunSession` - Remove extraneous `next_invocation` block, which would treat a processor error as a node error - Simplify loops - Add some callbacks for testing, to be removed before merge
This commit is contained in:
parent
82b4298b03
commit
be41c84305
@ -29,7 +29,7 @@ from ..services.model_images.model_images_default import ModelImageFileStorageDi
|
||||
from ..services.model_manager.model_manager_default import ModelManagerService
|
||||
from ..services.model_records import ModelRecordServiceSQL
|
||||
from ..services.names.names_default import SimpleNameService
|
||||
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
|
||||
from ..services.session_processor.session_processor_default import DefaultSessionProcessor, DefaultSessionRunner
|
||||
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||
from ..services.urls.urls_default import LocalUrlService
|
||||
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
||||
@ -103,7 +103,41 @@ class ApiDependencies:
|
||||
)
|
||||
names = SimpleNameService()
|
||||
performance_statistics = InvocationStatsService()
|
||||
session_processor = DefaultSessionProcessor()
|
||||
|
||||
def on_before_run_session(queue_item):
|
||||
print("BEFORE RUN SESSION", queue_item.item_id)
|
||||
return True
|
||||
|
||||
def on_before_run_node(invocation, queue_item):
|
||||
print("BEFORE RUN NODE", invocation.id)
|
||||
return True
|
||||
|
||||
def on_after_run_node(invocation, queue_item, outputs):
|
||||
print("AFTER RUN NODE", invocation.id)
|
||||
return True
|
||||
|
||||
def on_node_error(invocation, queue_item, exc_type, exc_value, exc_traceback):
|
||||
print("NODE ERROR", invocation.id)
|
||||
return True
|
||||
|
||||
def on_after_run_session(queue_item):
|
||||
print("AFTER RUN SESSION", queue_item.item_id)
|
||||
return True
|
||||
|
||||
def on_non_fatal_processor_error(queue_item, exc_type, exc_value, exc_traceback):
|
||||
print("NON FATAL PROCESSOR ERROR", exc_value)
|
||||
return True
|
||||
|
||||
session_processor = DefaultSessionProcessor(
|
||||
DefaultSessionRunner(
|
||||
on_before_run_session,
|
||||
on_before_run_node,
|
||||
on_after_run_node,
|
||||
on_node_error,
|
||||
on_after_run_session,
|
||||
),
|
||||
on_non_fatal_processor_error,
|
||||
)
|
||||
session_queue = SqliteSessionQueue(db=db)
|
||||
urls = LocalUrlService()
|
||||
workflow_records = SqliteWorkflowRecordsStorage(db=db)
|
||||
|
@ -1,6 +1,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from threading import Event
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
|
||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||
@ -22,12 +23,7 @@ class SessionRunnerBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def complete(self, queue_item: SessionQueueItem) -> None:
|
||||
"""Completes the session"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def run_node(self, node_id: str, queue_item: SessionQueueItem) -> None:
|
||||
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None:
|
||||
"""Runs an already prepared node on the session"""
|
||||
pass
|
||||
|
||||
|
@ -2,12 +2,13 @@ import traceback
|
||||
from contextlib import suppress
|
||||
from threading import BoundedSemaphore, Thread
|
||||
from threading import Event as ThreadEvent
|
||||
from typing import Callable, Optional, Union
|
||||
from types import TracebackType
|
||||
from typing import Callable, Optional, TypeAlias
|
||||
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event as FastAPIEvent
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
|
||||
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
||||
@ -19,73 +20,71 @@ from ..invoker import Invoker
|
||||
from .session_processor_base import InvocationServices, SessionProcessorBase, SessionRunnerBase
|
||||
from .session_processor_common import SessionProcessorStatus
|
||||
|
||||
OnBeforeRunNode: TypeAlias = Callable[[BaseInvocation, SessionQueueItem], bool]
|
||||
OnAfterRunNode: TypeAlias = Callable[[BaseInvocation, SessionQueueItem, BaseInvocationOutput], bool]
|
||||
OnNodeError: TypeAlias = Callable[[BaseInvocation, SessionQueueItem, type, BaseException, TracebackType], bool]
|
||||
OnBeforeRunSession: TypeAlias = Callable[[SessionQueueItem], bool]
|
||||
OnAfterRunSession: TypeAlias = Callable[[SessionQueueItem], bool]
|
||||
OnNonFatalProcessorError: TypeAlias = Callable[[Optional[SessionQueueItem], type, BaseException, TracebackType], bool]
|
||||
|
||||
|
||||
def get_stacktrace(exc_type: type, exc_value: BaseException, exc_traceback: TracebackType) -> str:
|
||||
return "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))
|
||||
|
||||
|
||||
class DefaultSessionRunner(SessionRunnerBase):
|
||||
"""Processes a single session's invocations"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
|
||||
on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
|
||||
on_before_run_session: Optional[OnBeforeRunSession] = None,
|
||||
on_before_run_node: Optional[OnBeforeRunNode] = None,
|
||||
on_after_run_node: Optional[OnAfterRunNode] = None,
|
||||
on_node_error: Optional[OnNodeError] = None,
|
||||
on_after_run_session: Optional[OnAfterRunSession] = None,
|
||||
):
|
||||
self.on_before_run_session = on_before_run_session
|
||||
self.on_before_run_node = on_before_run_node
|
||||
self.on_after_run_node = on_after_run_node
|
||||
self.on_node_error = on_node_error
|
||||
self.on_after_run_session = on_after_run_session
|
||||
|
||||
def start(self, services: InvocationServices, cancel_event: ThreadEvent):
|
||||
"""Start the session runner"""
|
||||
self.services = services
|
||||
self.cancel_event = cancel_event
|
||||
|
||||
def next_invocation(
|
||||
self, previous_invocation: Optional[BaseInvocation], queue_item: SessionQueueItem, cancel_event: ThreadEvent
|
||||
) -> Optional[BaseInvocation]:
|
||||
invocation = None
|
||||
if not (queue_item.session.is_complete() or cancel_event.is_set()):
|
||||
try:
|
||||
def run(self, queue_item: SessionQueueItem, profiler: Optional[Profiler] = None):
|
||||
"""Run the graph"""
|
||||
# Loop over invocations until the session is complete or canceled
|
||||
|
||||
self._on_before_run_session(queue_item=queue_item)
|
||||
while True:
|
||||
invocation = queue_item.session.next()
|
||||
except Exception as exc:
|
||||
self.services.logger.error("ERROR: %s" % exc, exc_info=True)
|
||||
if invocation is None or self.cancel_event.is_set():
|
||||
break
|
||||
self.run_node(invocation, queue_item)
|
||||
if queue_item.session.is_complete() or self.cancel_event.is_set():
|
||||
break
|
||||
self._on_after_run_session(queue_item=queue_item)
|
||||
|
||||
node_error = str(exc)
|
||||
def _on_before_run_session(self, queue_item: SessionQueueItem, profiler: Optional[Profiler] = None) -> None:
|
||||
# If profiling is enabled, start the profiler
|
||||
if profiler is not None:
|
||||
profiler.start(profile_id=queue_item.session_id)
|
||||
|
||||
# Save error
|
||||
if previous_invocation is not None:
|
||||
queue_item.session.set_node_error(previous_invocation.id, node_error)
|
||||
if self.on_before_run_session:
|
||||
self.on_before_run_session(queue_item)
|
||||
|
||||
# Send error event
|
||||
self.services.events.emit_invocation_error(
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session.id,
|
||||
node=previous_invocation.model_dump() if previous_invocation else {},
|
||||
source_node_id=queue_item.session.prepared_source_mapping[previous_invocation.id]
|
||||
if previous_invocation
|
||||
else "",
|
||||
error_type=exc.__class__.__name__,
|
||||
error=node_error,
|
||||
user_id=None,
|
||||
project_id=None,
|
||||
def _on_after_run_session(self, queue_item: SessionQueueItem, profiler: Optional[Profiler] = None) -> None:
|
||||
# If we are profiling, stop the profiler and dump the profile & stats
|
||||
if profiler:
|
||||
profile_path = profiler.stop()
|
||||
stats_path = profile_path.with_suffix(".json")
|
||||
self.services.performance_statistics.dump_stats(
|
||||
graph_execution_state_id=queue_item.session.id, output_path=stats_path
|
||||
)
|
||||
|
||||
if queue_item.session.is_complete() or cancel_event.is_set():
|
||||
# Set the invocation to None to prepare for the next session
|
||||
invocation = None
|
||||
return invocation
|
||||
|
||||
def run(self, queue_item: SessionQueueItem):
|
||||
"""Run the graph"""
|
||||
if not queue_item.session:
|
||||
raise ValueError("Queue item has no session")
|
||||
invocation = None
|
||||
# Loop over invocations until the session is complete or canceled
|
||||
invocation = self.next_invocation(invocation, queue_item, self.cancel_event)
|
||||
while invocation is not None and not self.cancel_event.is_set():
|
||||
self.run_node(invocation.id, queue_item)
|
||||
invocation = self.next_invocation(invocation, queue_item, self.cancel_event)
|
||||
self.complete(queue_item)
|
||||
|
||||
def complete(self, queue_item: SessionQueueItem):
|
||||
# Send complete event
|
||||
self.services.events.emit_graph_execution_complete(
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
@ -93,12 +92,16 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session.id,
|
||||
)
|
||||
|
||||
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
|
||||
# we don't care about that - suppress the error.
|
||||
with suppress(GESStatsNotFoundError):
|
||||
self.services.performance_statistics.log_stats(queue_item.session.id)
|
||||
self.services.performance_statistics.reset_stats()
|
||||
|
||||
if self.on_after_run_session:
|
||||
self.on_after_run_session(queue_item)
|
||||
|
||||
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||
"""Run before a node is executed"""
|
||||
# Send starting event
|
||||
@ -110,28 +113,73 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
)
|
||||
# And run lifecycle callbacks
|
||||
if self.on_before_run_node is not None:
|
||||
self.on_before_run_node(invocation, queue_item)
|
||||
|
||||
def _on_after_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||
def _on_after_run_node(
|
||||
self, invocation: BaseInvocation, queue_item: SessionQueueItem, outputs: BaseInvocationOutput
|
||||
):
|
||||
"""Run after a node is executed"""
|
||||
# Send complete event on successful runs
|
||||
self.services.events.emit_invocation_complete(
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session.id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
result=outputs.model_dump(),
|
||||
)
|
||||
# And run lifecycle callbacks
|
||||
if self.on_after_run_node is not None:
|
||||
self.on_after_run_node(invocation, queue_item)
|
||||
self.on_after_run_node(invocation, queue_item, outputs)
|
||||
|
||||
def run_node(self, node_id: str, queue_item: SessionQueueItem):
|
||||
def _on_node_error(
|
||||
self,
|
||||
invocation: BaseInvocation,
|
||||
queue_item: SessionQueueItem,
|
||||
exc_type: type,
|
||||
exc_value: BaseException,
|
||||
exc_traceback: TracebackType,
|
||||
):
|
||||
stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback)
|
||||
|
||||
queue_item.session.set_node_error(invocation.id, stacktrace)
|
||||
self.services.logger.error(
|
||||
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{exc_type}"
|
||||
)
|
||||
self.services.logger.error(stacktrace)
|
||||
|
||||
# Send error event
|
||||
self.services.events.emit_invocation_error(
|
||||
queue_batch_id=queue_item.session_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session.id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
error_type=exc_type.__name__,
|
||||
error=stacktrace,
|
||||
user_id=None,
|
||||
project_id=None,
|
||||
)
|
||||
|
||||
if self.on_node_error is not None:
|
||||
self.on_node_error(invocation, queue_item, exc_type, exc_value, exc_traceback)
|
||||
|
||||
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||
"""Run a single node in the graph"""
|
||||
# If this error raises a NodeNotFoundError that's handled by the processor
|
||||
invocation = queue_item.session.execution_graph.get_node(node_id)
|
||||
try:
|
||||
# Any unhandled exception is an invocation error & will fail the graph
|
||||
with self.services.performance_statistics.collect_stats(invocation, queue_item.session_id):
|
||||
self._on_before_run_node(invocation, queue_item)
|
||||
|
||||
data = InvocationContextData(
|
||||
invocation=invocation,
|
||||
source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
queue_item=queue_item,
|
||||
)
|
||||
|
||||
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
|
||||
with self.services.performance_statistics.collect_stats(invocation, queue_item.session_id):
|
||||
context = build_invocation_context(
|
||||
data=data,
|
||||
services=self.services,
|
||||
@ -140,21 +188,11 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
|
||||
# Invoke the node
|
||||
outputs = invocation.invoke_internal(context=context, services=self.services)
|
||||
|
||||
# Save outputs and history
|
||||
queue_item.session.complete(invocation.id, outputs)
|
||||
|
||||
self._on_after_run_node(invocation, queue_item)
|
||||
# Send complete event on successful runs
|
||||
self.services.events.emit_invocation_complete(
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session.id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=data.source_invocation_id,
|
||||
result=outputs.model_dump(),
|
||||
)
|
||||
self._on_after_run_node(invocation, queue_item, outputs)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
# TODO(MM2): Create an event for this
|
||||
pass
|
||||
@ -171,48 +209,51 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
# loop go to its next iteration, and the cancel event will be handled correctly.
|
||||
pass
|
||||
except Exception as e:
|
||||
error = traceback.format_exc()
|
||||
|
||||
# Save error
|
||||
queue_item.session.set_node_error(invocation.id, error)
|
||||
self.services.logger.error(
|
||||
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{e}"
|
||||
)
|
||||
self.services.logger.error(error)
|
||||
|
||||
# Send error event
|
||||
self.services.events.emit_invocation_error(
|
||||
queue_batch_id=queue_item.session_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session.id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
error_type=e.__class__.__name__,
|
||||
error=error,
|
||||
user_id=None,
|
||||
project_id=None,
|
||||
)
|
||||
exc_type = type(e)
|
||||
exc_value = e
|
||||
exc_traceback = e.__traceback__
|
||||
assert exc_traceback is not None
|
||||
self._on_node_error(invocation, queue_item, exc_type, exc_value, exc_traceback)
|
||||
|
||||
|
||||
class DefaultSessionProcessor(SessionProcessorBase):
|
||||
def __init__(self, session_runner: Union[SessionRunnerBase, None] = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
session_runner: Optional[SessionRunnerBase] = None,
|
||||
on_non_fatal_processor_error: Optional[OnNonFatalProcessorError] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.session_runner = session_runner if session_runner else DefaultSessionRunner()
|
||||
self.on_non_fatal_processor_error = on_non_fatal_processor_error
|
||||
|
||||
def _on_non_fatal_processor_error(
|
||||
self,
|
||||
queue_item: Optional[SessionQueueItem],
|
||||
exc_type: type,
|
||||
exc_value: BaseException,
|
||||
exc_traceback: TracebackType,
|
||||
) -> None:
|
||||
stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback)
|
||||
# Non-fatal error in processor
|
||||
self._invoker.services.logger.error(f"Non-fatal error in session processor:\n{stacktrace}")
|
||||
# Cancel the queue item
|
||||
if queue_item is not None:
|
||||
self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
|
||||
self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace)
|
||||
|
||||
if self.on_non_fatal_processor_error:
|
||||
self.on_non_fatal_processor_error(queue_item, exc_type, exc_value, exc_traceback)
|
||||
|
||||
def start(
|
||||
self,
|
||||
invoker: Invoker,
|
||||
thread_limit: int = 1,
|
||||
polling_interval: int = 1,
|
||||
on_before_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
|
||||
on_after_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
|
||||
) -> None:
|
||||
self._invoker: Invoker = invoker
|
||||
self._queue_item: Optional[SessionQueueItem] = None
|
||||
self._invocation: Optional[BaseInvocation] = None
|
||||
self.on_before_run_session = on_before_run_session
|
||||
self.on_after_run_session = on_after_run_session
|
||||
|
||||
self._resume_event = ThreadEvent()
|
||||
self._stop_event = ThreadEvent()
|
||||
@ -331,40 +372,15 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
|
||||
cancel_event.clear()
|
||||
|
||||
# If we have a on_before_run_session callback, call it
|
||||
if self.on_before_run_session is not None:
|
||||
self.on_before_run_session(self._queue_item)
|
||||
|
||||
# If profiling is enabled, start the profiler
|
||||
if self._profiler is not None:
|
||||
self._profiler.start(profile_id=self._queue_item.session_id)
|
||||
|
||||
# Run the graph
|
||||
self.session_runner.run(queue_item=self._queue_item)
|
||||
|
||||
# If we are profiling, stop the profiler and dump the profile & stats
|
||||
if self._profiler:
|
||||
profile_path = self._profiler.stop()
|
||||
stats_path = profile_path.with_suffix(".json")
|
||||
self._invoker.services.performance_statistics.dump_stats(
|
||||
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
|
||||
)
|
||||
|
||||
except Exception:
|
||||
# Non-fatal error in processor
|
||||
self._invoker.services.logger.error(
|
||||
f"Non-fatal error in session processor:\n{traceback.format_exc()}"
|
||||
)
|
||||
# Cancel the queue item
|
||||
if self._queue_item is not None:
|
||||
self._invoker.services.session_queue.set_queue_item_session(
|
||||
self._queue_item.item_id, self._queue_item.session
|
||||
)
|
||||
self._invoker.services.session_queue.cancel_queue_item(
|
||||
self._queue_item.item_id, error=traceback.format_exc()
|
||||
)
|
||||
# Reset the invocation to None to prepare for the next session
|
||||
self._invocation = None
|
||||
except Exception as e:
|
||||
exc_type = type(e)
|
||||
exc_value = e
|
||||
exc_traceback = e.__traceback__
|
||||
assert exc_traceback is not None
|
||||
self._on_non_fatal_processor_error(self._queue_item, exc_type, exc_value, exc_traceback)
|
||||
# Immediately poll for next queue item
|
||||
poll_now_event.wait(self._polling_interval)
|
||||
continue
|
||||
|
Loading…
Reference in New Issue
Block a user