mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Refractor session runner, move profiling back to processor, create abstract class for session runners, create path for passing in custom session runner to default session processor
This commit is contained in:
parent
46c904d08a
commit
71ee28ac12
@ -1,6 +1,35 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from threading import Event
|
||||||
|
|
||||||
|
from invokeai.app.services.invocation_services import InvocationServices
|
||||||
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
|
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
|
||||||
|
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||||
|
|
||||||
|
|
||||||
|
class SessionRunnerBase(ABC):
|
||||||
|
"""
|
||||||
|
Base class for session runner.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def start(self, services: InvocationServices, cancel_event: Event) -> None:
|
||||||
|
"""Starts the session runner"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def run(self, queue_item: SessionQueueItem) -> None:
|
||||||
|
"""Runs the session"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def complete(self, queue_item: SessionQueueItem) -> None:
|
||||||
|
"""Completes the session"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def run_node(self, node_id: str, queue_item: SessionQueueItem) -> None:
|
||||||
|
"""Runs an already prepared node on the session"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class SessionProcessorBase(ABC):
|
class SessionProcessorBase(ABC):
|
||||||
|
@ -17,34 +17,30 @@ from invokeai.app.services.shared.invocation_context import InvocationContextDat
|
|||||||
from invokeai.app.util.profiler import Profiler
|
from invokeai.app.util.profiler import Profiler
|
||||||
|
|
||||||
from ..invoker import Invoker
|
from ..invoker import Invoker
|
||||||
from .session_processor_base import SessionProcessorBase
|
from .session_processor_base import SessionProcessorBase, SessionRunnerBase
|
||||||
from .session_processor_common import SessionProcessorStatus
|
from .session_processor_common import SessionProcessorStatus
|
||||||
|
|
||||||
|
|
||||||
class SessionRunner:
|
class DefaultSessionRunner(SessionRunnerBase):
|
||||||
"""Processes a single session's invocations"""
|
"""Processes a single session's invocations"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
services: InvocationServices,
|
|
||||||
cancel_event: ThreadEvent,
|
|
||||||
profiler: Union[Profiler, None] = None,
|
|
||||||
on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
|
on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
|
||||||
on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
|
on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
|
||||||
):
|
):
|
||||||
self.services = services
|
|
||||||
self.profiler = profiler
|
|
||||||
self.cancel_event = cancel_event
|
|
||||||
self.on_before_run_node = on_before_run_node
|
self.on_before_run_node = on_before_run_node
|
||||||
self.on_after_run_node = on_after_run_node
|
self.on_after_run_node = on_after_run_node
|
||||||
|
|
||||||
|
def start(self, services: InvocationServices, cancel_event: ThreadEvent):
|
||||||
|
"""Start the session runner"""
|
||||||
|
self.services = services
|
||||||
|
self.cancel_event = cancel_event
|
||||||
|
|
||||||
def run(self, queue_item: SessionQueueItem):
|
def run(self, queue_item: SessionQueueItem):
|
||||||
"""Run the graph"""
|
"""Run the graph"""
|
||||||
if not queue_item.session:
|
if not queue_item.session:
|
||||||
raise ValueError("Queue item has no session")
|
raise ValueError("Queue item has no session")
|
||||||
# If profiling is enabled, start the profiler
|
|
||||||
if self.profiler is not None:
|
|
||||||
self.profiler.start(profile_id=queue_item.session_id)
|
|
||||||
# Loop over invocations until the session is complete or canceled
|
# Loop over invocations until the session is complete or canceled
|
||||||
while not (queue_item.session.is_complete() or self.cancel_event.is_set()):
|
while not (queue_item.session.is_complete() or self.cancel_event.is_set()):
|
||||||
# Prepare the next node
|
# Prepare the next node
|
||||||
@ -53,7 +49,7 @@ class SessionRunner:
|
|||||||
# If there are no more invocations, complete the graph
|
# If there are no more invocations, complete the graph
|
||||||
break
|
break
|
||||||
# Build invocation context (the node-facing API
|
# Build invocation context (the node-facing API
|
||||||
self.run_node(invocation, queue_item)
|
self.run_node(invocation.id, queue_item)
|
||||||
self.complete(queue_item)
|
self.complete(queue_item)
|
||||||
|
|
||||||
def complete(self, queue_item: SessionQueueItem):
|
def complete(self, queue_item: SessionQueueItem):
|
||||||
@ -64,41 +60,38 @@ class SessionRunner:
|
|||||||
queue_id=queue_item.queue_id,
|
queue_id=queue_item.queue_id,
|
||||||
graph_execution_state_id=queue_item.session.id,
|
graph_execution_state_id=queue_item.session.id,
|
||||||
)
|
)
|
||||||
# If we are profiling, stop the profiler and dump the profile & stats
|
|
||||||
if self.profiler:
|
|
||||||
profile_path = self.profiler.stop()
|
|
||||||
stats_path = profile_path.with_suffix(".json")
|
|
||||||
self.services.performance_statistics.dump_stats(
|
|
||||||
graph_execution_state_id=queue_item.session.id, output_path=stats_path
|
|
||||||
)
|
|
||||||
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
|
|
||||||
# we don't care about that - suppress the error.
|
|
||||||
with suppress(GESStatsNotFoundError):
|
|
||||||
self.services.performance_statistics.log_stats(queue_item.session.id)
|
|
||||||
self.services.performance_statistics.reset_stats()
|
|
||||||
|
|
||||||
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||||
"""Run a single node in the graph"""
|
"""Run before a node is executed"""
|
||||||
# If we have a on_before_run_node callback, call it
|
# Send starting event
|
||||||
|
self.services.events.emit_invocation_started(
|
||||||
|
queue_batch_id=queue_item.batch_id,
|
||||||
|
queue_item_id=queue_item.item_id,
|
||||||
|
queue_id=queue_item.queue_id,
|
||||||
|
graph_execution_state_id=queue_item.session_id,
|
||||||
|
node=invocation.model_dump(),
|
||||||
|
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||||
|
)
|
||||||
if self.on_before_run_node is not None:
|
if self.on_before_run_node is not None:
|
||||||
self.on_before_run_node(invocation, queue_item)
|
self.on_before_run_node(invocation, queue_item)
|
||||||
|
|
||||||
|
def _on_after_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||||
|
"""Run after a node is executed"""
|
||||||
|
if self.on_after_run_node is not None:
|
||||||
|
self.on_after_run_node(invocation, queue_item)
|
||||||
|
|
||||||
|
def run_node(self, node_id: str, queue_item: SessionQueueItem):
|
||||||
|
"""Run a single node in the graph"""
|
||||||
|
# If this error raises a NodeNotFoundError that's handled by the processor
|
||||||
|
invocation = queue_item.session.execution_graph.get_node(node_id)
|
||||||
try:
|
try:
|
||||||
|
self._on_before_run_node(invocation, queue_item)
|
||||||
data = InvocationContextData(
|
data = InvocationContextData(
|
||||||
invocation=invocation,
|
invocation=invocation,
|
||||||
source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
|
source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||||
queue_item=queue_item,
|
queue_item=queue_item,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Send starting event
|
|
||||||
self.services.events.emit_invocation_started(
|
|
||||||
queue_batch_id=queue_item.batch_id,
|
|
||||||
queue_item_id=queue_item.item_id,
|
|
||||||
queue_id=queue_item.queue_id,
|
|
||||||
graph_execution_state_id=queue_item.session_id,
|
|
||||||
node=invocation.model_dump(),
|
|
||||||
source_node_id=data.source_invocation_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
|
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
|
||||||
with self.services.performance_statistics.collect_stats(invocation, queue_item.session_id):
|
with self.services.performance_statistics.collect_stats(invocation, queue_item.session_id):
|
||||||
context = build_invocation_context(
|
context = build_invocation_context(
|
||||||
@ -113,7 +106,8 @@ class SessionRunner:
|
|||||||
# Save outputs and history
|
# Save outputs and history
|
||||||
queue_item.session.complete(invocation.id, outputs)
|
queue_item.session.complete(invocation.id, outputs)
|
||||||
|
|
||||||
# Send complete event
|
self._on_after_run_node(invocation, queue_item)
|
||||||
|
# Send complete event on successful runs
|
||||||
self.services.events.emit_invocation_complete(
|
self.services.events.emit_invocation_complete(
|
||||||
queue_batch_id=queue_item.batch_id,
|
queue_batch_id=queue_item.batch_id,
|
||||||
queue_item_id=queue_item.item_id,
|
queue_item_id=queue_item.item_id,
|
||||||
@ -159,23 +153,20 @@ class SessionRunner:
|
|||||||
error_type=e.__class__.__name__,
|
error_type=e.__class__.__name__,
|
||||||
error=error,
|
error=error,
|
||||||
)
|
)
|
||||||
pass
|
|
||||||
finally:
|
|
||||||
# If we have a on_after_run_node callback, call it
|
|
||||||
if self.on_after_run_node is not None:
|
|
||||||
self.on_after_run_node(invocation, queue_item)
|
|
||||||
|
|
||||||
|
|
||||||
class DefaultSessionProcessor(SessionProcessorBase):
|
class DefaultSessionProcessor(SessionProcessorBase):
|
||||||
"""Processes sessions from the session queue"""
|
"""Processes sessions from the session queue"""
|
||||||
|
|
||||||
|
def __init__(self, session_runner: Union[SessionRunnerBase, None] = None) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.session_runner = session_runner if session_runner else DefaultSessionRunner()
|
||||||
|
|
||||||
def start(
|
def start(
|
||||||
self,
|
self,
|
||||||
invoker: Invoker,
|
invoker: Invoker,
|
||||||
thread_limit: int = 1,
|
thread_limit: int = 1,
|
||||||
polling_interval: int = 1,
|
polling_interval: int = 1,
|
||||||
on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
|
|
||||||
on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
|
|
||||||
on_before_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
|
on_before_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
|
||||||
on_after_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
|
on_after_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -208,14 +199,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
self.session_runner = SessionRunner(
|
|
||||||
services=self._invoker.services,
|
|
||||||
cancel_event=self._cancel_event,
|
|
||||||
profiler=self._profiler,
|
|
||||||
on_before_run_node=on_before_run_node,
|
|
||||||
on_after_run_node=on_after_run_node,
|
|
||||||
)
|
|
||||||
|
|
||||||
self._thread = Thread(
|
self._thread = Thread(
|
||||||
name="session_processor",
|
name="session_processor",
|
||||||
target=self._process,
|
target=self._process,
|
||||||
@ -226,6 +209,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
"cancel_event": self._cancel_event,
|
"cancel_event": self._cancel_event,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event)
|
||||||
self._thread.start()
|
self._thread.start()
|
||||||
|
|
||||||
def stop(self, *args, **kwargs) -> None:
|
def stop(self, *args, **kwargs) -> None:
|
||||||
@ -281,16 +265,35 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
# Get the next session to process
|
# Get the next session to process
|
||||||
self._queue_item = self._invoker.services.session_queue.dequeue()
|
self._queue_item = self._invoker.services.session_queue.dequeue()
|
||||||
if self._queue_item is not None and resume_event.is_set():
|
if self._queue_item is not None and resume_event.is_set():
|
||||||
# If we have a on_before_run_session callback, call it
|
|
||||||
if self.on_before_run_session is not None:
|
|
||||||
self.on_before_run_session(self._queue_item)
|
|
||||||
|
|
||||||
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
|
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
|
||||||
cancel_event.clear()
|
cancel_event.clear()
|
||||||
|
|
||||||
|
# If we have a on_before_run_session callback, call it
|
||||||
|
if self.on_before_run_session is not None:
|
||||||
|
self.on_before_run_session(self._queue_item)
|
||||||
|
|
||||||
|
# If profiling is enabled, start the profiler
|
||||||
|
if self._profiler is not None:
|
||||||
|
self._profiler.start(profile_id=self._queue_item.session_id)
|
||||||
|
|
||||||
# Run the graph
|
# Run the graph
|
||||||
self.session_runner.run(queue_item=self._queue_item)
|
self.session_runner.run(queue_item=self._queue_item)
|
||||||
|
|
||||||
|
# If we are profiling, stop the profiler and dump the profile & stats
|
||||||
|
if self._profiler:
|
||||||
|
profile_path = self._profiler.stop()
|
||||||
|
stats_path = profile_path.with_suffix(".json")
|
||||||
|
self._invoker.services.performance_statistics.dump_stats(
|
||||||
|
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
|
||||||
|
)
|
||||||
|
|
||||||
|
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
|
||||||
|
# we don't care about that - suppress the error.
|
||||||
|
with suppress(GESStatsNotFoundError):
|
||||||
|
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
|
||||||
|
self._invoker.services.performance_statistics.reset_stats()
|
||||||
|
|
||||||
# If we have a on_after_run_session callback, call it
|
# If we have a on_after_run_session callback, call it
|
||||||
if self.on_after_run_session is not None:
|
if self.on_after_run_session is not None:
|
||||||
self.on_after_run_session(self._queue_item)
|
self.on_after_run_session(self._queue_item)
|
||||||
|
Loading…
Reference in New Issue
Block a user