mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
docs(processor): update docstrings, comments
This commit is contained in:
parent
6063487b20
commit
a98ddedb95
@ -16,17 +16,33 @@ class SessionRunnerBase(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def start(self, services: InvocationServices, cancel_event: Event, profiler: Optional[Profiler] = None) -> None:
|
def start(self, services: InvocationServices, cancel_event: Event, profiler: Optional[Profiler] = None) -> None:
|
||||||
"""Starts the session runner"""
|
"""Starts the session runner.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
services: The invocation services.
|
||||||
|
cancel_event: The cancel event.
|
||||||
|
profiler: The profiler to use for session profiling via cProfile. Omit to disable profiling. Basic session
|
||||||
|
stats will be still be recorded and logged when profiling is disabled.
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def run(self, queue_item: SessionQueueItem) -> None:
|
def run(self, queue_item: SessionQueueItem) -> None:
|
||||||
"""Runs the session"""
|
"""Runs a session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
queue_item: The session to run.
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None:
|
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None:
|
||||||
"""Runs an already prepared node on the session"""
|
"""Run a single node in the graph.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
invocation: The invocation to run.
|
||||||
|
queue_item: The session queue item.
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -56,13 +72,25 @@ class SessionProcessorBase(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class OnBeforeRunNode(Protocol):
|
class OnBeforeRunNode(Protocol):
|
||||||
def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> bool: ...
|
def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None:
|
||||||
|
"""Callback to run before executing a node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
invocation: The invocation that will be executed.
|
||||||
|
queue_item: The session queue item.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
class OnAfterRunNode(Protocol):
|
class OnAfterRunNode(Protocol):
|
||||||
def __call__(
|
def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput) -> None:
|
||||||
self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput
|
"""Callback to run before executing a node.
|
||||||
) -> bool: ...
|
|
||||||
|
Args:
|
||||||
|
invocation: The invocation that was executed.
|
||||||
|
queue_item: The session queue item.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
class OnNodeError(Protocol):
|
class OnNodeError(Protocol):
|
||||||
@ -73,15 +101,37 @@ class OnNodeError(Protocol):
|
|||||||
error_type: str,
|
error_type: str,
|
||||||
error_message: str,
|
error_message: str,
|
||||||
error_traceback: str,
|
error_traceback: str,
|
||||||
) -> bool: ...
|
) -> None:
|
||||||
|
"""Callback to run when a node has an error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
invocation: The invocation that errored.
|
||||||
|
queue_item: The session queue item.
|
||||||
|
error_type: The type of error, e.g. "ValueError".
|
||||||
|
error_message: The error message, e.g. "Invalid value".
|
||||||
|
error_traceback: The stringified error traceback.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
class OnBeforeRunSession(Protocol):
|
class OnBeforeRunSession(Protocol):
|
||||||
def __call__(self, queue_item: SessionQueueItem) -> bool: ...
|
def __call__(self, queue_item: SessionQueueItem) -> None:
|
||||||
|
"""Callback to run before executing a session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
queue_item: The session queue item.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
class OnAfterRunSession(Protocol):
|
class OnAfterRunSession(Protocol):
|
||||||
def __call__(self, queue_item: SessionQueueItem) -> bool: ...
|
def __call__(self, queue_item: SessionQueueItem) -> None:
|
||||||
|
"""Callback to run after executing a session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
queue_item: The session queue item.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
class OnNonFatalProcessorError(Protocol):
|
class OnNonFatalProcessorError(Protocol):
|
||||||
@ -91,4 +141,13 @@ class OnNonFatalProcessorError(Protocol):
|
|||||||
error_type: str,
|
error_type: str,
|
||||||
error_message: str,
|
error_message: str,
|
||||||
error_traceback: str,
|
error_traceback: str,
|
||||||
) -> bool: ...
|
) -> None:
|
||||||
|
"""Callback to run when a non-fatal error occurs in the processor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
queue_item: The session queue item, if one was being executed when the error occurred.
|
||||||
|
error_type: The type of error, e.g. "ValueError".
|
||||||
|
error_message: The error message, e.g. "Invalid value".
|
||||||
|
error_traceback: The stringified error traceback.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
@ -30,7 +30,7 @@ from .session_processor_common import SessionProcessorStatus
|
|||||||
|
|
||||||
|
|
||||||
class DefaultSessionRunner(SessionRunnerBase):
|
class DefaultSessionRunner(SessionRunnerBase):
|
||||||
"""Processes a single session's invocations"""
|
"""Processes a single session's invocations."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -40,6 +40,15 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
on_node_error_callbacks: Optional[list[OnNodeError]] = None,
|
on_node_error_callbacks: Optional[list[OnNodeError]] = None,
|
||||||
on_after_run_session_callbacks: Optional[list[OnAfterRunSession]] = None,
|
on_after_run_session_callbacks: Optional[list[OnAfterRunSession]] = None,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
on_before_run_session_callbacks: Callbacks to run before the session starts.
|
||||||
|
on_before_run_node_callbacks: Callbacks to run before each node starts.
|
||||||
|
on_after_run_node_callbacks: Callbacks to run after each node completes.
|
||||||
|
on_node_error_callbacks: Callbacks to run when a node errors.
|
||||||
|
on_after_run_session_callbacks: Callbacks to run after the session completes.
|
||||||
|
"""
|
||||||
|
|
||||||
self._on_before_run_session_callbacks = on_before_run_session_callbacks or []
|
self._on_before_run_session_callbacks = on_before_run_session_callbacks or []
|
||||||
self._on_before_run_node_callbacks = on_before_run_node_callbacks or []
|
self._on_before_run_node_callbacks = on_before_run_node_callbacks or []
|
||||||
self._on_after_run_node_callbacks = on_after_run_node_callbacks or []
|
self._on_after_run_node_callbacks = on_after_run_node_callbacks or []
|
||||||
@ -47,14 +56,12 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
self._on_after_run_session_callbacks = on_after_run_session_callbacks or []
|
self._on_after_run_session_callbacks = on_after_run_session_callbacks or []
|
||||||
|
|
||||||
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
|
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
|
||||||
"""Start the session runner"""
|
|
||||||
self._services = services
|
self._services = services
|
||||||
self._cancel_event = cancel_event
|
self._cancel_event = cancel_event
|
||||||
self._profiler = profiler
|
self._profiler = profiler
|
||||||
|
|
||||||
def run(self, queue_item: SessionQueueItem):
|
def run(self, queue_item: SessionQueueItem):
|
||||||
"""Run the graph"""
|
# Exceptions raised outside `run_node` are handled by the processor. There is no need to catch them here.
|
||||||
# Exceptions raised outside `run_node` are handled by the processor.
|
|
||||||
|
|
||||||
self._on_before_run_session(queue_item=queue_item)
|
self._on_before_run_session(queue_item=queue_item)
|
||||||
|
|
||||||
@ -78,14 +85,16 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
|
|
||||||
if invocation is None or self._cancel_event.is_set():
|
if invocation is None or self._cancel_event.is_set():
|
||||||
break
|
break
|
||||||
|
|
||||||
self.run_node(invocation, queue_item)
|
self.run_node(invocation, queue_item)
|
||||||
|
|
||||||
|
# The session is complete if all invocations have been run or there is an error on the session.
|
||||||
if queue_item.session.is_complete() or self._cancel_event.is_set():
|
if queue_item.session.is_complete() or self._cancel_event.is_set():
|
||||||
break
|
break
|
||||||
|
|
||||||
self._on_after_run_session(queue_item=queue_item)
|
self._on_after_run_session(queue_item=queue_item)
|
||||||
|
|
||||||
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||||
"""Run a single node in the graph"""
|
|
||||||
try:
|
try:
|
||||||
# Any unhandled exception in this scope is an invocation error & will fail the graph
|
# Any unhandled exception in this scope is an invocation error & will fail the graph
|
||||||
with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id):
|
with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id):
|
||||||
@ -110,7 +119,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
self._on_after_run_node(invocation, queue_item, output)
|
self._on_after_run_node(invocation, queue_item, output)
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
# TODO(MM2): Create an event for this
|
# TODO(psyche): This is expected to be caught in the main thread. Do we need to catch this here?
|
||||||
pass
|
pass
|
||||||
except CanceledException:
|
except CanceledException:
|
||||||
# When the user cancels the graph, we first set the cancel event. The event is checked
|
# When the user cancels the graph, we first set the cancel event. The event is checked
|
||||||
@ -137,6 +146,8 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
|
def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
|
||||||
|
"""Run before a session is executed"""
|
||||||
|
|
||||||
# If profiling is enabled, start the profiler
|
# If profiling is enabled, start the profiler
|
||||||
if self._profiler is not None:
|
if self._profiler is not None:
|
||||||
self._profiler.start(profile_id=queue_item.session_id)
|
self._profiler.start(profile_id=queue_item.session_id)
|
||||||
@ -145,6 +156,8 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
callback(queue_item=queue_item)
|
callback(queue_item=queue_item)
|
||||||
|
|
||||||
def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
|
def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
|
||||||
|
"""Run after a session is executed"""
|
||||||
|
|
||||||
# If we are profiling, stop the profiler and dump the profile & stats
|
# If we are profiling, stop the profiler and dump the profile & stats
|
||||||
if self._profiler is not None:
|
if self._profiler is not None:
|
||||||
profile_path = self._profiler.stop()
|
profile_path = self._profiler.stop()
|
||||||
@ -156,7 +169,8 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
# Update the queue item with the completed session
|
# Update the queue item with the completed session
|
||||||
self._services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
|
self._services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
|
||||||
|
|
||||||
# Send complete event
|
# TODO(psyche): This feels jumbled - we should review separation of concerns here.
|
||||||
|
# Send complete event. The events service will receive this and update the queue item's status.
|
||||||
self._services.events.emit_graph_execution_complete(
|
self._services.events.emit_graph_execution_complete(
|
||||||
queue_batch_id=queue_item.batch_id,
|
queue_batch_id=queue_item.batch_id,
|
||||||
queue_item_id=queue_item.item_id,
|
queue_item_id=queue_item.item_id,
|
||||||
@ -175,6 +189,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
|
|
||||||
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||||
"""Run before a node is executed"""
|
"""Run before a node is executed"""
|
||||||
|
|
||||||
# Send starting event
|
# Send starting event
|
||||||
self._services.events.emit_invocation_started(
|
self._services.events.emit_invocation_started(
|
||||||
queue_batch_id=queue_item.batch_id,
|
queue_batch_id=queue_item.batch_id,
|
||||||
@ -192,6 +207,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput
|
self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput
|
||||||
):
|
):
|
||||||
"""Run after a node is executed"""
|
"""Run after a node is executed"""
|
||||||
|
|
||||||
# Send complete event on successful runs
|
# Send complete event on successful runs
|
||||||
self._services.events.emit_invocation_complete(
|
self._services.events.emit_invocation_complete(
|
||||||
queue_batch_id=queue_item.batch_id,
|
queue_batch_id=queue_item.batch_id,
|
||||||
@ -214,6 +230,8 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
error_message: str,
|
error_message: str,
|
||||||
error_traceback: str,
|
error_traceback: str,
|
||||||
):
|
):
|
||||||
|
"""Run when a node errors"""
|
||||||
|
|
||||||
# Node errors do not get the full traceback. Only the queue item gets the full traceback.
|
# Node errors do not get the full traceback. Only the queue item gets the full traceback.
|
||||||
node_error = f"{error_type}: {error_message}"
|
node_error = f"{error_type}: {error_message}"
|
||||||
queue_item.session.set_node_error(invocation.id, node_error)
|
queue_item.session.set_node_error(invocation.id, node_error)
|
||||||
@ -356,8 +374,8 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
resume_event: ThreadEvent,
|
resume_event: ThreadEvent,
|
||||||
cancel_event: ThreadEvent,
|
cancel_event: ThreadEvent,
|
||||||
):
|
):
|
||||||
# Outermost processor try block; any unhandled exception is a fatal processor error
|
|
||||||
try:
|
try:
|
||||||
|
# Any unhandled exception in this block is a fatal processor error and will stop the processor.
|
||||||
self._thread_semaphore.acquire()
|
self._thread_semaphore.acquire()
|
||||||
stop_event.clear()
|
stop_event.clear()
|
||||||
resume_event.set()
|
resume_event.set()
|
||||||
@ -365,8 +383,8 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
|
|
||||||
while not stop_event.is_set():
|
while not stop_event.is_set():
|
||||||
poll_now_event.clear()
|
poll_now_event.clear()
|
||||||
# Middle processor try block; any unhandled exception is a non-fatal processor error
|
|
||||||
try:
|
try:
|
||||||
|
# Any unhandled exception in this block is a nonfatal processor error and will be handled.
|
||||||
# If we are paused, wait for resume event
|
# If we are paused, wait for resume event
|
||||||
resume_event.wait()
|
resume_event.wait()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user