From a98ddedb9576a4393fc2d34bfd5aff343ff55eb3 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 24 May 2024 10:20:20 +1000 Subject: [PATCH] docs(processor): update docstrings, comments --- .../session_processor_base.py | 81 ++++++++++++++++--- .../session_processor_default.py | 36 ++++++--- 2 files changed, 97 insertions(+), 20 deletions(-) diff --git a/invokeai/app/services/session_processor/session_processor_base.py b/invokeai/app/services/session_processor/session_processor_base.py index 1436627a9e..15611bb5f8 100644 --- a/invokeai/app/services/session_processor/session_processor_base.py +++ b/invokeai/app/services/session_processor/session_processor_base.py @@ -16,17 +16,33 @@ class SessionRunnerBase(ABC): @abstractmethod def start(self, services: InvocationServices, cancel_event: Event, profiler: Optional[Profiler] = None) -> None: - """Starts the session runner""" + """Starts the session runner. + + Args: + services: The invocation services. + cancel_event: The cancel event. + profiler: The profiler to use for session profiling via cProfile. Omit to disable profiling. Basic session + stats will be still be recorded and logged when profiling is disabled. + """ pass @abstractmethod def run(self, queue_item: SessionQueueItem) -> None: - """Runs the session""" + """Runs a session. + + Args: + queue_item: The session to run. + """ pass @abstractmethod def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None: - """Runs an already prepared node on the session""" + """Run a single node in the graph. + + Args: + invocation: The invocation to run. + queue_item: The session queue item. + """ pass @@ -56,13 +72,25 @@ class SessionProcessorBase(ABC): class OnBeforeRunNode(Protocol): - def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> bool: ... + def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None: + """Callback to run before executing a node. + + Args: + invocation: The invocation that will be executed. + queue_item: The session queue item. + """ + ... class OnAfterRunNode(Protocol): - def __call__( - self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput - ) -> bool: ... + def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput) -> None: + """Callback to run before executing a node. + + Args: + invocation: The invocation that was executed. + queue_item: The session queue item. + """ + ... class OnNodeError(Protocol): @@ -73,15 +101,37 @@ class OnNodeError(Protocol): error_type: str, error_message: str, error_traceback: str, - ) -> bool: ... + ) -> None: + """Callback to run when a node has an error. + + Args: + invocation: The invocation that errored. + queue_item: The session queue item. + error_type: The type of error, e.g. "ValueError". + error_message: The error message, e.g. "Invalid value". + error_traceback: The stringified error traceback. + """ + ... class OnBeforeRunSession(Protocol): - def __call__(self, queue_item: SessionQueueItem) -> bool: ... + def __call__(self, queue_item: SessionQueueItem) -> None: + """Callback to run before executing a session. + + Args: + queue_item: The session queue item. + """ + ... class OnAfterRunSession(Protocol): - def __call__(self, queue_item: SessionQueueItem) -> bool: ... + def __call__(self, queue_item: SessionQueueItem) -> None: + """Callback to run after executing a session. + + Args: + queue_item: The session queue item. + """ + ... class OnNonFatalProcessorError(Protocol): @@ -91,4 +141,13 @@ class OnNonFatalProcessorError(Protocol): error_type: str, error_message: str, error_traceback: str, - ) -> bool: ... + ) -> None: + """Callback to run when a non-fatal error occurs in the processor. + + Args: + queue_item: The session queue item, if one was being executed when the error occurred. + error_type: The type of error, e.g. "ValueError". + error_message: The error message, e.g. "Invalid value". + error_traceback: The stringified error traceback. + """ + ... diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 49277a105d..eec835af87 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -30,7 +30,7 @@ from .session_processor_common import SessionProcessorStatus class DefaultSessionRunner(SessionRunnerBase): - """Processes a single session's invocations""" + """Processes a single session's invocations.""" def __init__( self, @@ -40,6 +40,15 @@ class DefaultSessionRunner(SessionRunnerBase): on_node_error_callbacks: Optional[list[OnNodeError]] = None, on_after_run_session_callbacks: Optional[list[OnAfterRunSession]] = None, ): + """ + Args: + on_before_run_session_callbacks: Callbacks to run before the session starts. + on_before_run_node_callbacks: Callbacks to run before each node starts. + on_after_run_node_callbacks: Callbacks to run after each node completes. + on_node_error_callbacks: Callbacks to run when a node errors. + on_after_run_session_callbacks: Callbacks to run after the session completes. + """ + self._on_before_run_session_callbacks = on_before_run_session_callbacks or [] self._on_before_run_node_callbacks = on_before_run_node_callbacks or [] self._on_after_run_node_callbacks = on_after_run_node_callbacks or [] @@ -47,14 +56,12 @@ class DefaultSessionRunner(SessionRunnerBase): self._on_after_run_session_callbacks = on_after_run_session_callbacks or [] def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None): - """Start the session runner""" self._services = services self._cancel_event = cancel_event self._profiler = profiler def run(self, queue_item: SessionQueueItem): - """Run the graph""" - # Exceptions raised outside `run_node` are handled by the processor. + # Exceptions raised outside `run_node` are handled by the processor. There is no need to catch them here. self._on_before_run_session(queue_item=queue_item) @@ -78,14 +85,16 @@ class DefaultSessionRunner(SessionRunnerBase): if invocation is None or self._cancel_event.is_set(): break + self.run_node(invocation, queue_item) + + # The session is complete if all invocations have been run or there is an error on the session. if queue_item.session.is_complete() or self._cancel_event.is_set(): break self._on_after_run_session(queue_item=queue_item) def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): - """Run a single node in the graph""" try: # Any unhandled exception in this scope is an invocation error & will fail the graph with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id): @@ -110,7 +119,7 @@ class DefaultSessionRunner(SessionRunnerBase): self._on_after_run_node(invocation, queue_item, output) except KeyboardInterrupt: - # TODO(MM2): Create an event for this + # TODO(psyche): This is expected to be caught in the main thread. Do we need to catch this here? pass except CanceledException: # When the user cancels the graph, we first set the cancel event. The event is checked @@ -137,6 +146,8 @@ class DefaultSessionRunner(SessionRunnerBase): ) def _on_before_run_session(self, queue_item: SessionQueueItem) -> None: + """Run before a session is executed""" + # If profiling is enabled, start the profiler if self._profiler is not None: self._profiler.start(profile_id=queue_item.session_id) @@ -145,6 +156,8 @@ class DefaultSessionRunner(SessionRunnerBase): callback(queue_item=queue_item) def _on_after_run_session(self, queue_item: SessionQueueItem) -> None: + """Run after a session is executed""" + # If we are profiling, stop the profiler and dump the profile & stats if self._profiler is not None: profile_path = self._profiler.stop() @@ -156,7 +169,8 @@ class DefaultSessionRunner(SessionRunnerBase): # Update the queue item with the completed session self._services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session) - # Send complete event + # TODO(psyche): This feels jumbled - we should review separation of concerns here. + # Send complete event. The events service will receive this and update the queue item's status. self._services.events.emit_graph_execution_complete( queue_batch_id=queue_item.batch_id, queue_item_id=queue_item.item_id, @@ -175,6 +189,7 @@ class DefaultSessionRunner(SessionRunnerBase): def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): """Run before a node is executed""" + # Send starting event self._services.events.emit_invocation_started( queue_batch_id=queue_item.batch_id, @@ -192,6 +207,7 @@ class DefaultSessionRunner(SessionRunnerBase): self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput ): """Run after a node is executed""" + # Send complete event on successful runs self._services.events.emit_invocation_complete( queue_batch_id=queue_item.batch_id, @@ -214,6 +230,8 @@ class DefaultSessionRunner(SessionRunnerBase): error_message: str, error_traceback: str, ): + """Run when a node errors""" + # Node errors do not get the full traceback. Only the queue item gets the full traceback. node_error = f"{error_type}: {error_message}" queue_item.session.set_node_error(invocation.id, node_error) @@ -356,8 +374,8 @@ class DefaultSessionProcessor(SessionProcessorBase): resume_event: ThreadEvent, cancel_event: ThreadEvent, ): - # Outermost processor try block; any unhandled exception is a fatal processor error try: + # Any unhandled exception in this block is a fatal processor error and will stop the processor. self._thread_semaphore.acquire() stop_event.clear() resume_event.set() @@ -365,8 +383,8 @@ class DefaultSessionProcessor(SessionProcessorBase): while not stop_event.is_set(): poll_now_event.clear() - # Middle processor try block; any unhandled exception is a non-fatal processor error try: + # Any unhandled exception in this block is a nonfatal processor error and will be handled. # If we are paused, wait for resume event resume_event.wait()