Refractor session runner, move profiling back to processor, create abstract class for session runners, create path for passing in custom session runner to default session processor

This commit is contained in:
Brandon Rising 2024-03-05 16:00:34 -05:00 committed by Brandon
parent 46c904d08a
commit 71ee28ac12
2 changed files with 88 additions and 56 deletions

View File

@ -1,6 +1,35 @@
from abc import ABC, abstractmethod
from threading import Event
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
class SessionRunnerBase(ABC):
"""
Base class for session runner.
"""
@abstractmethod
def start(self, services: InvocationServices, cancel_event: Event) -> None:
"""Starts the session runner"""
pass
@abstractmethod
def run(self, queue_item: SessionQueueItem) -> None:
"""Runs the session"""
pass
@abstractmethod
def complete(self, queue_item: SessionQueueItem) -> None:
"""Completes the session"""
pass
@abstractmethod
def run_node(self, node_id: str, queue_item: SessionQueueItem) -> None:
"""Runs an already prepared node on the session"""
pass
class SessionProcessorBase(ABC):

View File

@ -17,34 +17,30 @@ from invokeai.app.services.shared.invocation_context import InvocationContextDat
from invokeai.app.util.profiler import Profiler
from ..invoker import Invoker
from .session_processor_base import SessionProcessorBase
from .session_processor_base import SessionProcessorBase, SessionRunnerBase
from .session_processor_common import SessionProcessorStatus
class SessionRunner:
class DefaultSessionRunner(SessionRunnerBase):
"""Processes a single session's invocations"""
def __init__(
self,
services: InvocationServices,
cancel_event: ThreadEvent,
profiler: Union[Profiler, None] = None,
on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
):
self.services = services
self.profiler = profiler
self.cancel_event = cancel_event
self.on_before_run_node = on_before_run_node
self.on_after_run_node = on_after_run_node
def start(self, services: InvocationServices, cancel_event: ThreadEvent):
"""Start the session runner"""
self.services = services
self.cancel_event = cancel_event
def run(self, queue_item: SessionQueueItem):
"""Run the graph"""
if not queue_item.session:
raise ValueError("Queue item has no session")
# If profiling is enabled, start the profiler
if self.profiler is not None:
self.profiler.start(profile_id=queue_item.session_id)
# Loop over invocations until the session is complete or canceled
while not (queue_item.session.is_complete() or self.cancel_event.is_set()):
# Prepare the next node
@ -53,7 +49,7 @@ class SessionRunner:
# If there are no more invocations, complete the graph
break
# Build invocation context (the node-facing API
self.run_node(invocation, queue_item)
self.run_node(invocation.id, queue_item)
self.complete(queue_item)
def complete(self, queue_item: SessionQueueItem):
@ -64,31 +60,9 @@ class SessionRunner:
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session.id,
)
# If we are profiling, stop the profiler and dump the profile & stats
if self.profiler:
profile_path = self.profiler.stop()
stats_path = profile_path.with_suffix(".json")
self.services.performance_statistics.dump_stats(
graph_execution_state_id=queue_item.session.id, output_path=stats_path
)
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self.services.performance_statistics.log_stats(queue_item.session.id)
self.services.performance_statistics.reset_stats()
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Run a single node in the graph"""
# If we have a on_before_run_node callback, call it
if self.on_before_run_node is not None:
self.on_before_run_node(invocation, queue_item)
try:
data = InvocationContextData(
invocation=invocation,
source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
queue_item=queue_item,
)
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Run before a node is executed"""
# Send starting event
self.services.events.emit_invocation_started(
queue_batch_id=queue_item.batch_id,
@ -96,7 +70,26 @@ class SessionRunner:
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session_id,
node=invocation.model_dump(),
source_node_id=data.source_invocation_id,
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
)
if self.on_before_run_node is not None:
self.on_before_run_node(invocation, queue_item)
def _on_after_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Run after a node is executed"""
if self.on_after_run_node is not None:
self.on_after_run_node(invocation, queue_item)
def run_node(self, node_id: str, queue_item: SessionQueueItem):
"""Run a single node in the graph"""
# If this error raises a NodeNotFoundError that's handled by the processor
invocation = queue_item.session.execution_graph.get_node(node_id)
try:
self._on_before_run_node(invocation, queue_item)
data = InvocationContextData(
invocation=invocation,
source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
queue_item=queue_item,
)
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
@ -113,7 +106,8 @@ class SessionRunner:
# Save outputs and history
queue_item.session.complete(invocation.id, outputs)
# Send complete event
self._on_after_run_node(invocation, queue_item)
# Send complete event on successful runs
self.services.events.emit_invocation_complete(
queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id,
@ -159,23 +153,20 @@ class SessionRunner:
error_type=e.__class__.__name__,
error=error,
)
pass
finally:
# If we have a on_after_run_node callback, call it
if self.on_after_run_node is not None:
self.on_after_run_node(invocation, queue_item)
class DefaultSessionProcessor(SessionProcessorBase):
"""Processes sessions from the session queue"""
def __init__(self, session_runner: Union[SessionRunnerBase, None] = None) -> None:
super().__init__()
self.session_runner = session_runner if session_runner else DefaultSessionRunner()
def start(
self,
invoker: Invoker,
thread_limit: int = 1,
polling_interval: int = 1,
on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
on_before_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
on_after_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
) -> None:
@ -208,14 +199,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
else None
)
self.session_runner = SessionRunner(
services=self._invoker.services,
cancel_event=self._cancel_event,
profiler=self._profiler,
on_before_run_node=on_before_run_node,
on_after_run_node=on_after_run_node,
)
self._thread = Thread(
name="session_processor",
target=self._process,
@ -226,6 +209,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
"cancel_event": self._cancel_event,
},
)
self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event)
self._thread.start()
def stop(self, *args, **kwargs) -> None:
@ -281,16 +265,35 @@ class DefaultSessionProcessor(SessionProcessorBase):
# Get the next session to process
self._queue_item = self._invoker.services.session_queue.dequeue()
if self._queue_item is not None and resume_event.is_set():
# If we have a on_before_run_session callback, call it
if self.on_before_run_session is not None:
self.on_before_run_session(self._queue_item)
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
cancel_event.clear()
# If we have a on_before_run_session callback, call it
if self.on_before_run_session is not None:
self.on_before_run_session(self._queue_item)
# If profiling is enabled, start the profiler
if self._profiler is not None:
self._profiler.start(profile_id=self._queue_item.session_id)
# Run the graph
self.session_runner.run(queue_item=self._queue_item)
# If we are profiling, stop the profiler and dump the profile & stats
if self._profiler:
profile_path = self._profiler.stop()
stats_path = profile_path.with_suffix(".json")
self._invoker.services.performance_statistics.dump_stats(
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
)
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
self._invoker.services.performance_statistics.reset_stats()
# If we have a on_after_run_session callback, call it
if self.on_after_run_session is not None:
self.on_after_run_session(self._queue_item)