Refractor session runner, move profiling back to processor, create abstract class for session runners, create path for passing in custom session runner to default session processor

This commit is contained in:
Brandon Rising 2024-03-05 16:00:34 -05:00 committed by Brandon
parent 46c904d08a
commit 71ee28ac12
2 changed files with 88 additions and 56 deletions

View File

@ -1,6 +1,35 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from threading import Event
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
class SessionRunnerBase(ABC):
"""
Base class for session runner.
"""
@abstractmethod
def start(self, services: InvocationServices, cancel_event: Event) -> None:
"""Starts the session runner"""
pass
@abstractmethod
def run(self, queue_item: SessionQueueItem) -> None:
"""Runs the session"""
pass
@abstractmethod
def complete(self, queue_item: SessionQueueItem) -> None:
"""Completes the session"""
pass
@abstractmethod
def run_node(self, node_id: str, queue_item: SessionQueueItem) -> None:
"""Runs an already prepared node on the session"""
pass
class SessionProcessorBase(ABC): class SessionProcessorBase(ABC):

View File

@ -17,34 +17,30 @@ from invokeai.app.services.shared.invocation_context import InvocationContextDat
from invokeai.app.util.profiler import Profiler from invokeai.app.util.profiler import Profiler
from ..invoker import Invoker from ..invoker import Invoker
from .session_processor_base import SessionProcessorBase from .session_processor_base import SessionProcessorBase, SessionRunnerBase
from .session_processor_common import SessionProcessorStatus from .session_processor_common import SessionProcessorStatus
class SessionRunner: class DefaultSessionRunner(SessionRunnerBase):
"""Processes a single session's invocations""" """Processes a single session's invocations"""
def __init__( def __init__(
self, self,
services: InvocationServices,
cancel_event: ThreadEvent,
profiler: Union[Profiler, None] = None,
on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None, on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None, on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
): ):
self.services = services
self.profiler = profiler
self.cancel_event = cancel_event
self.on_before_run_node = on_before_run_node self.on_before_run_node = on_before_run_node
self.on_after_run_node = on_after_run_node self.on_after_run_node = on_after_run_node
def start(self, services: InvocationServices, cancel_event: ThreadEvent):
"""Start the session runner"""
self.services = services
self.cancel_event = cancel_event
def run(self, queue_item: SessionQueueItem): def run(self, queue_item: SessionQueueItem):
"""Run the graph""" """Run the graph"""
if not queue_item.session: if not queue_item.session:
raise ValueError("Queue item has no session") raise ValueError("Queue item has no session")
# If profiling is enabled, start the profiler
if self.profiler is not None:
self.profiler.start(profile_id=queue_item.session_id)
# Loop over invocations until the session is complete or canceled # Loop over invocations until the session is complete or canceled
while not (queue_item.session.is_complete() or self.cancel_event.is_set()): while not (queue_item.session.is_complete() or self.cancel_event.is_set()):
# Prepare the next node # Prepare the next node
@ -53,7 +49,7 @@ class SessionRunner:
# If there are no more invocations, complete the graph # If there are no more invocations, complete the graph
break break
# Build invocation context (the node-facing API # Build invocation context (the node-facing API
self.run_node(invocation, queue_item) self.run_node(invocation.id, queue_item)
self.complete(queue_item) self.complete(queue_item)
def complete(self, queue_item: SessionQueueItem): def complete(self, queue_item: SessionQueueItem):
@ -64,41 +60,38 @@ class SessionRunner:
queue_id=queue_item.queue_id, queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session.id, graph_execution_state_id=queue_item.session.id,
) )
# If we are profiling, stop the profiler and dump the profile & stats
if self.profiler:
profile_path = self.profiler.stop()
stats_path = profile_path.with_suffix(".json")
self.services.performance_statistics.dump_stats(
graph_execution_state_id=queue_item.session.id, output_path=stats_path
)
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self.services.performance_statistics.log_stats(queue_item.session.id)
self.services.performance_statistics.reset_stats()
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Run a single node in the graph""" """Run before a node is executed"""
# If we have a on_before_run_node callback, call it # Send starting event
self.services.events.emit_invocation_started(
queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session_id,
node=invocation.model_dump(),
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
)
if self.on_before_run_node is not None: if self.on_before_run_node is not None:
self.on_before_run_node(invocation, queue_item) self.on_before_run_node(invocation, queue_item)
def _on_after_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Run after a node is executed"""
if self.on_after_run_node is not None:
self.on_after_run_node(invocation, queue_item)
def run_node(self, node_id: str, queue_item: SessionQueueItem):
"""Run a single node in the graph"""
# If this error raises a NodeNotFoundError that's handled by the processor
invocation = queue_item.session.execution_graph.get_node(node_id)
try: try:
self._on_before_run_node(invocation, queue_item)
data = InvocationContextData( data = InvocationContextData(
invocation=invocation, invocation=invocation,
source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id], source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
queue_item=queue_item, queue_item=queue_item,
) )
# Send starting event
self.services.events.emit_invocation_started(
queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session_id,
node=invocation.model_dump(),
source_node_id=data.source_invocation_id,
)
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph # Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
with self.services.performance_statistics.collect_stats(invocation, queue_item.session_id): with self.services.performance_statistics.collect_stats(invocation, queue_item.session_id):
context = build_invocation_context( context = build_invocation_context(
@ -113,7 +106,8 @@ class SessionRunner:
# Save outputs and history # Save outputs and history
queue_item.session.complete(invocation.id, outputs) queue_item.session.complete(invocation.id, outputs)
# Send complete event self._on_after_run_node(invocation, queue_item)
# Send complete event on successful runs
self.services.events.emit_invocation_complete( self.services.events.emit_invocation_complete(
queue_batch_id=queue_item.batch_id, queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id, queue_item_id=queue_item.item_id,
@ -159,23 +153,20 @@ class SessionRunner:
error_type=e.__class__.__name__, error_type=e.__class__.__name__,
error=error, error=error,
) )
pass
finally:
# If we have a on_after_run_node callback, call it
if self.on_after_run_node is not None:
self.on_after_run_node(invocation, queue_item)
class DefaultSessionProcessor(SessionProcessorBase): class DefaultSessionProcessor(SessionProcessorBase):
"""Processes sessions from the session queue""" """Processes sessions from the session queue"""
def __init__(self, session_runner: Union[SessionRunnerBase, None] = None) -> None:
super().__init__()
self.session_runner = session_runner if session_runner else DefaultSessionRunner()
def start( def start(
self, self,
invoker: Invoker, invoker: Invoker,
thread_limit: int = 1, thread_limit: int = 1,
polling_interval: int = 1, polling_interval: int = 1,
on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
on_before_run_session: Union[Callable[[SessionQueueItem], bool], None] = None, on_before_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
on_after_run_session: Union[Callable[[SessionQueueItem], bool], None] = None, on_after_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
) -> None: ) -> None:
@ -208,14 +199,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
else None else None
) )
self.session_runner = SessionRunner(
services=self._invoker.services,
cancel_event=self._cancel_event,
profiler=self._profiler,
on_before_run_node=on_before_run_node,
on_after_run_node=on_after_run_node,
)
self._thread = Thread( self._thread = Thread(
name="session_processor", name="session_processor",
target=self._process, target=self._process,
@ -226,6 +209,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
"cancel_event": self._cancel_event, "cancel_event": self._cancel_event,
}, },
) )
self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event)
self._thread.start() self._thread.start()
def stop(self, *args, **kwargs) -> None: def stop(self, *args, **kwargs) -> None:
@ -281,16 +265,35 @@ class DefaultSessionProcessor(SessionProcessorBase):
# Get the next session to process # Get the next session to process
self._queue_item = self._invoker.services.session_queue.dequeue() self._queue_item = self._invoker.services.session_queue.dequeue()
if self._queue_item is not None and resume_event.is_set(): if self._queue_item is not None and resume_event.is_set():
# If we have a on_before_run_session callback, call it
if self.on_before_run_session is not None:
self.on_before_run_session(self._queue_item)
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}") self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
cancel_event.clear() cancel_event.clear()
# If we have a on_before_run_session callback, call it
if self.on_before_run_session is not None:
self.on_before_run_session(self._queue_item)
# If profiling is enabled, start the profiler
if self._profiler is not None:
self._profiler.start(profile_id=self._queue_item.session_id)
# Run the graph # Run the graph
self.session_runner.run(queue_item=self._queue_item) self.session_runner.run(queue_item=self._queue_item)
# If we are profiling, stop the profiler and dump the profile & stats
if self._profiler:
profile_path = self._profiler.stop()
stats_path = profile_path.with_suffix(".json")
self._invoker.services.performance_statistics.dump_stats(
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
)
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
self._invoker.services.performance_statistics.reset_stats()
# If we have a on_after_run_session callback, call it # If we have a on_after_run_session callback, call it
if self.on_after_run_session is not None: if self.on_after_run_session is not None:
self.on_after_run_session(self._queue_item) self.on_after_run_session(self._queue_item)