From 45d2504c1e1bd966fb79eb3b4b1880a21ec3ae77 Mon Sep 17 00:00:00 2001 From: brandonrising Date: Thu, 16 May 2024 13:30:04 -0400 Subject: [PATCH] Break apart session processor and the running of each session into separate classes --- .../session_processor_base.py | 28 ++ .../session_processor_default.py | 339 +++++++++++------- 2 files changed, 236 insertions(+), 131 deletions(-) diff --git a/invokeai/app/services/session_processor/session_processor_base.py b/invokeai/app/services/session_processor/session_processor_base.py index 485ef2f8c3..745430d201 100644 --- a/invokeai/app/services/session_processor/session_processor_base.py +++ b/invokeai/app/services/session_processor/session_processor_base.py @@ -1,6 +1,34 @@ from abc import ABC, abstractmethod +from threading import Event +from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus +from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem + +class SessionRunnerBase(ABC): + """ + Base class for session runner. + """ + + @abstractmethod + def start(self, services: InvocationServices, cancel_event: Event) -> None: + """Starts the session runner""" + pass + + @abstractmethod + def run(self, queue_item: SessionQueueItem) -> None: + """Runs the session""" + pass + + @abstractmethod + def complete(self, queue_item: SessionQueueItem) -> None: + """Completes the session""" + pass + + @abstractmethod + def run_node(self, node_id: str, queue_item: SessionQueueItem) -> None: + """Runs an already prepared node on the session""" + pass class SessionProcessorBase(ABC): diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 2a0ebc3168..eeb8df4cad 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -2,7 +2,7 @@ import traceback from contextlib import suppress from threading import BoundedSemaphore, Thread from threading import Event as ThreadEvent -from typing import Optional +from typing import Callable, Optional, Union from fastapi_events.handlers.local import local_handler from fastapi_events.typing import Event as FastAPIEvent @@ -16,15 +16,207 @@ from invokeai.app.services.shared.invocation_context import InvocationContextDat from invokeai.app.util.profiler import Profiler from ..invoker import Invoker -from .session_processor_base import SessionProcessorBase +from .session_processor_base import InvocationServices, SessionProcessorBase, SessionRunnerBase from .session_processor_common import SessionProcessorStatus +class DefaultSessionRunner(SessionRunnerBase): + """Processes a single session's invocations""" + + def __init__( + self, + on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None, + on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None, + ): + self.on_before_run_node = on_before_run_node + self.on_after_run_node = on_after_run_node + + def start(self, services: InvocationServices, cancel_event: ThreadEvent): + """Start the session runner""" + self.services = services + self.cancel_event = cancel_event + + def next_invocation( + self, previous_invocation: Optional[BaseInvocation], queue_item: SessionQueueItem, cancel_event: ThreadEvent + ) -> Optional[BaseInvocation]: + invocation = None + if not (queue_item.session.is_complete() or cancel_event.is_set()): + try: + invocation = queue_item.session.next() + except Exception as exc: + self.services.logger.error("ERROR: %s" % exc, exc_info=True) + + node_error = str(exc) + + # Save error + if previous_invocation is not None: + queue_item.session.set_node_error(previous_invocation.id, node_error) + + # Send error event + self.services.events.emit_invocation_error( + queue_batch_id=queue_item.batch_id, + queue_item_id=queue_item.item_id, + queue_id=queue_item.queue_id, + graph_execution_state_id=queue_item.session.id, + node=previous_invocation.model_dump() if previous_invocation else {}, + source_node_id=queue_item.session.prepared_source_mapping[previous_invocation.id] + if previous_invocation + else "", + error_type=exc.__class__.__name__, + error=node_error, + user_id=None, + project_id=None, + ) + + if queue_item.session.is_complete() or cancel_event.is_set(): + # Set the invocation to None to prepare for the next session + invocation = None + return invocation + + def run(self, queue_item: SessionQueueItem): + """Run the graph""" + if not queue_item.session: + raise ValueError("Queue item has no session") + invocation = None + # Loop over invocations until the session is complete or canceled + while self.next_invocation(invocation, queue_item, self.cancel_event) and not self.cancel_event.is_set(): + # Prepare the next node + invocation = queue_item.session.next() + if invocation is None: + # If there are no more invocations, complete the graph + break + # Build invocation context (the node-facing API + self.run_node(invocation.id, queue_item) + self.complete(queue_item) + + def complete(self, queue_item: SessionQueueItem): + # Send complete event + self.services.events.emit_graph_execution_complete( + queue_batch_id=queue_item.batch_id, + queue_item_id=queue_item.item_id, + queue_id=queue_item.queue_id, + graph_execution_state_id=queue_item.session.id, + ) + # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor + # we don't care about that - suppress the error. + with suppress(GESStatsNotFoundError): + self.services.performance_statistics.log_stats(queue_item.session.id) + self.services.performance_statistics.reset_stats() + + def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): + """Run before a node is executed""" + # Send starting event + self.services.events.emit_invocation_started( + queue_batch_id=queue_item.batch_id, + queue_item_id=queue_item.item_id, + queue_id=queue_item.queue_id, + graph_execution_state_id=queue_item.session_id, + node=invocation.model_dump(), + source_node_id=queue_item.session.prepared_source_mapping[invocation.id], + ) + if self.on_before_run_node is not None: + self.on_before_run_node(invocation, queue_item) + + def _on_after_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): + """Run after a node is executed""" + if self.on_after_run_node is not None: + self.on_after_run_node(invocation, queue_item) + + def run_node(self, node_id: str, queue_item: SessionQueueItem): + """Run a single node in the graph""" + # If this error raises a NodeNotFoundError that's handled by the processor + invocation = queue_item.session.execution_graph.get_node(node_id) + try: + self._on_before_run_node(invocation, queue_item) + data = InvocationContextData( + invocation=invocation, + source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id], + queue_item=queue_item, + ) + + # Innermost processor try block; any unhandled exception is an invocation error & will fail the graph + with self.services.performance_statistics.collect_stats(invocation, queue_item.session_id): + context = build_invocation_context( + data=data, + services=self.services, + cancel_event=self.cancel_event, + ) + + # Invoke the node + outputs = invocation.invoke_internal(context=context, services=self.services) + + # Save outputs and history + queue_item.session.complete(invocation.id, outputs) + + self._on_after_run_node(invocation, queue_item) + # Send complete event on successful runs + self.services.events.emit_invocation_complete( + queue_batch_id=queue_item.batch_id, + queue_item_id=queue_item.item_id, + queue_id=queue_item.queue_id, + graph_execution_state_id=queue_item.session.id, + node=invocation.model_dump(), + source_node_id=data.source_invocation_id, + result=outputs.model_dump(), + ) + except KeyboardInterrupt: + # TODO(MM2): Create an event for this + pass + except CanceledException: + # When the user cancels the graph, we first set the cancel event. The event is checked + # between invocations, in this loop. Some invocations are long-running, and we need to + # be able to cancel them mid-execution. + # + # For example, denoising is a long-running invocation with many steps. A step callback + # is executed after each step. This step callback checks if the canceled event is set, + # then raises a CanceledException to stop execution immediately. + # + # When we get a CanceledException, we don't need to do anything - just pass and let the + # loop go to its next iteration, and the cancel event will be handled correctly. + pass + except Exception as e: + error = traceback.format_exc() + + # Save error + queue_item.session.set_node_error(invocation.id, error) + self.services.logger.error( + f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{e}" + ) + self.services.logger.error(error) + + # Send error event + self.services.events.emit_invocation_error( + queue_batch_id=queue_item.session_id, + queue_item_id=queue_item.item_id, + queue_id=queue_item.queue_id, + graph_execution_state_id=queue_item.session.id, + node=invocation.model_dump(), + source_node_id=queue_item.session.prepared_source_mapping[invocation.id], + error_type=e.__class__.__name__, + error=error, + user_id=None, + project_id=None, + ) + + class DefaultSessionProcessor(SessionProcessorBase): - def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None: + def __init__(self, session_runner: Union[SessionRunnerBase, None] = None) -> None: + super().__init__() + self.session_runner = session_runner if session_runner else DefaultSessionRunner() + + def start( + self, + invoker: Invoker, + thread_limit: int = 1, + polling_interval: int = 1, + on_before_run_session: Union[Callable[[SessionQueueItem], bool], None] = None, + on_after_run_session: Union[Callable[[SessionQueueItem], bool], None] = None, + ) -> None: self._invoker: Invoker = invoker self._queue_item: Optional[SessionQueueItem] = None self._invocation: Optional[BaseInvocation] = None + self.on_before_run_session = on_before_run_session + self.on_after_run_session = on_after_run_session self._resume_event = ThreadEvent() self._stop_event = ThreadEvent() @@ -49,6 +241,7 @@ class DefaultSessionProcessor(SessionProcessorBase): else None ) + self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event) self._thread = Thread( name="session_processor", target=self._process, @@ -142,141 +335,25 @@ class DefaultSessionProcessor(SessionProcessorBase): self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}") cancel_event.clear() + # If we have a on_before_run_session callback, call it + if self.on_before_run_session is not None: + self.on_before_run_session(self._queue_item) + # If profiling is enabled, start the profiler if self._profiler is not None: self._profiler.start(profile_id=self._queue_item.session_id) - # Prepare invocations and take the first - self._invocation = self._queue_item.session.next() + # Run the graph + self.session_runner.run(queue_item=self._queue_item) - # Loop over invocations until the session is complete or canceled - while self._invocation is not None and not cancel_event.is_set(): - # get the source node id to provide to clients (the prepared node id is not as useful) - source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id] - - # Send starting event - self._invoker.services.events.emit_invocation_started( - queue_batch_id=self._queue_item.batch_id, - queue_item_id=self._queue_item.item_id, - queue_id=self._queue_item.queue_id, - graph_execution_state_id=self._queue_item.session_id, - node=self._invocation.model_dump(), - source_node_id=source_invocation_id, + # If we are profiling, stop the profiler and dump the profile & stats + if self._profiler: + profile_path = self._profiler.stop() + stats_path = profile_path.with_suffix(".json") + self._invoker.services.performance_statistics.dump_stats( + graph_execution_state_id=self._queue_item.session.id, output_path=stats_path ) - # Innermost processor try block; any unhandled exception is an invocation error & will fail the graph - try: - with self._invoker.services.performance_statistics.collect_stats( - self._invocation, self._queue_item.session.id - ): - # Build invocation context (the node-facing API) - data = InvocationContextData( - invocation=self._invocation, - source_invocation_id=source_invocation_id, - queue_item=self._queue_item, - ) - context = build_invocation_context( - data=data, - services=self._invoker.services, - cancel_event=self._cancel_event, - ) - - # Invoke the node - outputs = self._invocation.invoke_internal( - context=context, services=self._invoker.services - ) - - # Save outputs and history - self._queue_item.session.complete(self._invocation.id, outputs) - - # Send complete event - self._invoker.services.events.emit_invocation_complete( - queue_batch_id=self._queue_item.batch_id, - queue_item_id=self._queue_item.item_id, - queue_id=self._queue_item.queue_id, - graph_execution_state_id=self._queue_item.session.id, - node=self._invocation.model_dump(), - source_node_id=source_invocation_id, - result=outputs.model_dump(), - ) - - except KeyboardInterrupt: - # TODO(MM2): Create an event for this - pass - - except CanceledException: - # When the user cancels the graph, we first set the cancel event. The event is checked - # between invocations, in this loop. Some invocations are long-running, and we need to - # be able to cancel them mid-execution. - # - # For example, denoising is a long-running invocation with many steps. A step callback - # is executed after each step. This step callback checks if the canceled event is set, - # then raises a CanceledException to stop execution immediately. - # - # When we get a CanceledException, we don't need to do anything - just pass and let the - # loop go to its next iteration, and the cancel event will be handled correctly. - pass - - except Exception as e: - error = traceback.format_exc() - - # Save error - self._queue_item.session.set_node_error(self._invocation.id, error) - self._invoker.services.logger.error( - f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}" - ) - self._invoker.services.logger.error(error) - - # Send error event - self._invoker.services.events.emit_invocation_error( - queue_batch_id=self._queue_item.session_id, - queue_item_id=self._queue_item.item_id, - queue_id=self._queue_item.queue_id, - graph_execution_state_id=self._queue_item.session.id, - node=self._invocation.model_dump(), - source_node_id=source_invocation_id, - error_type=e.__class__.__name__, - error=error, - user_id=None, - project_id=None, - ) - pass - - # The session is complete if the all invocations are complete or there was an error - if self._queue_item.session.is_complete() or cancel_event.is_set(): - # Send complete event - self._invoker.services.session_queue.set_queue_item_session( - self._queue_item.item_id, self._queue_item.session - ) - self._invoker.services.events.emit_graph_execution_complete( - queue_batch_id=self._queue_item.batch_id, - queue_item_id=self._queue_item.item_id, - queue_id=self._queue_item.queue_id, - graph_execution_state_id=self._queue_item.session.id, - ) - # If we are profiling, stop the profiler and dump the profile & stats - if self._profiler: - profile_path = self._profiler.stop() - stats_path = profile_path.with_suffix(".json") - self._invoker.services.performance_statistics.dump_stats( - graph_execution_state_id=self._queue_item.session.id, output_path=stats_path - ) - # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor - # we don't care about that - suppress the error. - with suppress(GESStatsNotFoundError): - self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id) - self._invoker.services.performance_statistics.reset_stats() - - # Set the invocation to None to prepare for the next session - self._invocation = None - else: - # Prepare the next invocation - self._invocation = self._queue_item.session.next() - else: - # The queue was empty, wait for next polling interval or event to try again - self._invoker.services.logger.debug("Waiting for next polling interval or event") - poll_now_event.wait(self._polling_interval) - continue except Exception: # Non-fatal error in processor self._invoker.services.logger.error(