Separate the logic that actually runs a graph in the session_processor into its own class

This commit is contained in:
Brandon Rising 2024-03-04 15:23:51 -05:00 committed by Brandon
parent e30cb4b52f
commit afa4df1991
2 changed files with 171 additions and 131 deletions

View File

@ -1,7 +1,5 @@
import traceback
from contextlib import suppress
from threading import BoundedSemaphore, Thread
from threading import Event as ThreadEvent
from threading import BoundedSemaphore, Thread, Event as ThreadEvent
from typing import Optional
from fastapi_events.handlers.local import local_handler
@ -9,10 +7,8 @@ from fastapi_events.typing import Event as FastAPIEvent
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.shared.graph_processor import GraphProcessor
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
from invokeai.app.util.profiler import Profiler
from ..invoker import Invoker
@ -49,6 +45,12 @@ class DefaultSessionProcessor(SessionProcessorBase):
else None
)
self.graph_processor = GraphProcessor(
services=self._invoker.services,
cancel_event=self._cancel_event,
profiler=self._profiler,
)
self._thread = Thread(
name="session_processor",
target=self._process,
@ -117,131 +119,8 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
cancel_event.clear()
# If profiling is enabled, start the profiler
if self._profiler is not None:
self._profiler.start(profile_id=self._queue_item.session_id)
# Prepare invocations and take the first
self._invocation = self._queue_item.session.next()
# Loop over invocations until the session is complete or canceled
while self._invocation is not None and not cancel_event.is_set():
# get the source node id to provide to clients (the prepared node id is not as useful)
source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]
# Send starting event
self._invoker.services.events.emit_invocation_started(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session_id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
)
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
try:
with self._invoker.services.performance_statistics.collect_stats(
self._invocation, self._queue_item.session.id
):
# Build invocation context (the node-facing API)
data = InvocationContextData(
invocation=self._invocation,
source_invocation_id=source_invocation_id,
queue_item=self._queue_item,
)
context = build_invocation_context(
data=data,
services=self._invoker.services,
cancel_event=self._cancel_event,
)
# Invoke the node
outputs = self._invocation.invoke_internal(
context=context, services=self._invoker.services
)
# Save outputs and history
self._queue_item.session.complete(self._invocation.id, outputs)
# Send complete event
self._invoker.services.events.emit_invocation_complete(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
result=outputs.model_dump(),
)
except KeyboardInterrupt:
# TODO(MM2): Create an event for this
pass
except CanceledException:
# When the user cancels the graph, we first set the cancel event. The event is checked
# between invocations, in this loop. Some invocations are long-running, and we need to
# be able to cancel them mid-execution.
#
# For example, denoising is a long-running invocation with many steps. A step callback
# is executed after each step. This step callback checks if the canceled event is set,
# then raises a CanceledException to stop execution immediately.
#
# When we get a CanceledException, we don't need to do anything - just pass and let the
# loop go to its next iteration, and the cancel event will be handled correctly.
pass
except Exception as e:
error = traceback.format_exc()
# Save error
self._queue_item.session.set_node_error(self._invocation.id, error)
self._invoker.services.logger.error(
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
)
self._invoker.services.logger.error(error)
# Send error event
self._invoker.services.events.emit_invocation_error(
queue_batch_id=self._queue_item.session_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
error_type=e.__class__.__name__,
error=error,
)
pass
# The session is complete if the all invocations are complete or there was an error
if self._queue_item.session.is_complete() or cancel_event.is_set():
# Send complete event
self._invoker.services.events.emit_graph_execution_complete(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
)
# If we are profiling, stop the profiler and dump the profile & stats
if self._profiler:
profile_path = self._profiler.stop()
stats_path = profile_path.with_suffix(".json")
self._invoker.services.performance_statistics.dump_stats(
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
)
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
self._invoker.services.performance_statistics.reset_stats()
# Set the invocation to None to prepare for the next session
self._invocation = None
else:
# Prepare the next invocation
self._invocation = self._queue_item.session.next()
# Run the graph
self.graph_processor.run(queue_item=self._queue_item)
# The session is complete, immediately poll for next session
self._queue_item = None

View File

@ -0,0 +1,161 @@
import traceback
from contextlib import suppress
from threading import Event
from typing import Callable, Union
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
from invokeai.app.util.profiler import Profiler
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
from invokeai.app.services.session_processor.session_processor_common import CanceledException
class GraphProcessor:
"""Process a graph of invocations"""
def __init__(
self,
services: InvocationServices,
cancel_event: Event,
profiler: Union[Profiler, None] = None,
on_before_run_node: Union[Callable[[BaseInvocation,SessionQueueItem], bool], None] = None,
on_after_run_node: Union[Callable[[BaseInvocation,SessionQueueItem], bool], None] = None,
):
self.services = services
self.profiler = profiler
self.cancel_event = cancel_event
self.on_before_run_node = on_before_run_node
self.on_after_run_node = on_after_run_node
def run(self, queue_item: SessionQueueItem):
"""Run the graph"""
if not queue_item.session:
raise ValueError("Queue item has no session")
# If profiling is enabled, start the profiler
if self.profiler is not None:
self.profiler.start(profile_id=queue_item.session_id)
# Loop over invocations until the session is complete or canceled
while not (queue_item.session.is_complete() or self.cancel_event.is_set()):
# Prepare the next node
invocation = queue_item.session.next()
if invocation is None:
# If there are no more invocations, complete the graph
break
# Build invocation context (the node-facing API
self.run_node(invocation, queue_item)
self.complete(queue_item)
def complete(self, queue_item: SessionQueueItem):
"""Complete the graph"""
self.services.events.emit_graph_execution_complete(
queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session.id,
)
# If we are profiling, stop the profiler and dump the profile & stats
if self.profiler:
profile_path = self.profiler.stop()
stats_path = profile_path.with_suffix(".json")
self.services.performance_statistics.dump_stats(
graph_execution_state_id=queue_item.session.id, output_path=stats_path
)
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self.services.performance_statistics.log_stats(queue_item.session.id)
self.services.performance_statistics.reset_stats()
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Run a single node in the graph"""
# If we have a on_before_run_node callback, call it
if self.on_before_run_node is not None:
self.on_before_run_node(invocation, queue_item)
try:
data = InvocationContextData(
invocation=invocation,
source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
queue_item=queue_item,
)
# Send starting event
self.services.events.emit_invocation_started(
queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session_id,
node=invocation.model_dump(),
source_node_id=data.source_invocation_id,
)
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
with self.services.performance_statistics.collect_stats(
invocation, queue_item.session_id
):
context = build_invocation_context(
data=data,
services=self.services,
cancel_event=self.cancel_event,
)
# Invoke the node
outputs = invocation.invoke_internal(
context=context, services=self.services
)
# Save outputs and history
queue_item.session.complete(invocation.id, outputs)
# Send complete event
self.services.events.emit_invocation_complete(
queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session.id,
node=invocation.model_dump(),
source_node_id=data.source_invocation_id,
result=outputs.model_dump(),
)
except KeyboardInterrupt:
# TODO(MM2): Create an event for this
self.cancel_event.set()
except CanceledException:
# When the user cancels the graph, we first set the cancel event. The event is checked
# between invocations, in this loop. Some invocations are long-running, and we need to
# be able to cancel them mid-execution.
#
# For example, denoising is a long-running invocation with many steps. A step callback
# is executed after each step. This step callback checks if the canceled event is set,
# then raises a CanceledException to stop execution immediately.
#
# When we get a CanceledException, we don't need to do anything - just pass and let the
# loop go to its next iteration, and the cancel event will be handled correctly.
pass
except Exception as e:
error = traceback.format_exc()
# Save error
queue_item.session.set_node_error(invocation.id, error)
self.services.logger.error(
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{e}"
)
self.services.logger.error(error)
# Send error event
self.services.events.emit_invocation_error(
queue_batch_id=queue_item.session_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session.id,
node=invocation.model_dump(),
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
error_type=e.__class__.__name__,
error=error,
)
pass
finally:
# If we have a on_after_run_node callback, call it
if self.on_after_run_node is not None:
self.on_after_run_node(invocation, queue_item)