From 33f9fe2c864ad203db1ce9aabcd9a9f378751699 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 22 May 2024 19:07:57 +1000 Subject: [PATCH] tidy(app): rearrange proccessor --- .../session_processor_default.py | 134 +++++++++--------- 1 file changed, 68 insertions(+), 66 deletions(-) diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 74e7dd3deb..6e60d20853 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -30,6 +30,8 @@ from .session_processor_common import SessionProcessorStatus def get_stacktrace(exc_type: type, exc_value: BaseException, exc_traceback: TracebackType) -> str: + """Formats a stacktrace as a string""" + return "".join(traceback.format_exception(exc_type, exc_value, exc_traceback)) @@ -72,6 +74,54 @@ class DefaultSessionRunner(SessionRunnerBase): self._on_after_run_session(queue_item=queue_item) + def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): + """Run a single node in the graph""" + try: + # Any unhandled exception is an invocation error & will fail the graph + with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id): + self._on_before_run_node(invocation, queue_item) + + data = InvocationContextData( + invocation=invocation, + source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id], + queue_item=queue_item, + ) + context = build_invocation_context( + data=data, + services=self._services, + cancel_event=self._cancel_event, + ) + + # Invoke the node + outputs = invocation.invoke_internal(context=context, services=self._services) + # Save outputs and history + queue_item.session.complete(invocation.id, outputs) + + self._on_after_run_node(invocation, queue_item, outputs) + + except KeyboardInterrupt: + # TODO(MM2): Create an event for this + pass + except CanceledException: + # When the user cancels the graph, we first set the cancel event. The event is checked + # between invocations, in this loop. Some invocations are long-running, and we need to + # be able to cancel them mid-execution. + # + # For example, denoising is a long-running invocation with many steps. A step callback + # is executed after each step. This step callback checks if the canceled event is set, + # then raises a CanceledException to stop execution immediately. + # + # When we get a CanceledException, we don't need to do anything - just pass and let the + # loop go to its next iteration, and the cancel event will be handled correctly. + pass + except Exception as e: + # Must extract the exception traceback here to not lose its stacktrace when we change scope + exc_type = type(e) + exc_value = e + exc_traceback = e.__traceback__ + assert exc_traceback is not None + self._on_node_error(invocation, queue_item, exc_type, exc_value, exc_traceback) + def _on_before_run_session(self, queue_item: SessionQueueItem) -> None: # If profiling is enabled, start the profiler if self._profiler is not None: @@ -172,54 +222,6 @@ class DefaultSessionRunner(SessionRunnerBase): for callback in self._on_node_error_callbacks: callback(invocation, queue_item, exc_type, exc_value, exc_traceback) - def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): - """Run a single node in the graph""" - try: - # Any unhandled exception is an invocation error & will fail the graph - with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id): - self._on_before_run_node(invocation, queue_item) - - data = InvocationContextData( - invocation=invocation, - source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id], - queue_item=queue_item, - ) - context = build_invocation_context( - data=data, - services=self._services, - cancel_event=self._cancel_event, - ) - - # Invoke the node - outputs = invocation.invoke_internal(context=context, services=self._services) - # Save outputs and history - queue_item.session.complete(invocation.id, outputs) - - self._on_after_run_node(invocation, queue_item, outputs) - - except KeyboardInterrupt: - # TODO(MM2): Create an event for this - pass - except CanceledException: - # When the user cancels the graph, we first set the cancel event. The event is checked - # between invocations, in this loop. Some invocations are long-running, and we need to - # be able to cancel them mid-execution. - # - # For example, denoising is a long-running invocation with many steps. A step callback - # is executed after each step. This step callback checks if the canceled event is set, - # then raises a CanceledException to stop execution immediately. - # - # When we get a CanceledException, we don't need to do anything - just pass and let the - # loop go to its next iteration, and the cancel event will be handled correctly. - pass - except Exception as e: - # Must extract the exception traceback here to not lose its stacktrace when we change scope - exc_type = type(e) - exc_value = e - exc_traceback = e.__traceback__ - assert exc_traceback is not None - self._on_node_error(invocation, queue_item, exc_type, exc_value, exc_traceback) - class DefaultSessionProcessor(SessionProcessorBase): def __init__( @@ -236,24 +238,6 @@ class DefaultSessionProcessor(SessionProcessorBase): self._thread_limit = thread_limit self._polling_interval = polling_interval - def _on_non_fatal_processor_error( - self, - queue_item: Optional[SessionQueueItem], - exc_type: type, - exc_value: BaseException, - exc_traceback: TracebackType, - ) -> None: - stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback) - # Non-fatal error in processor - self._invoker.services.logger.error(f"Non-fatal error in session processor:\n{stacktrace}") - # Cancel the queue item - if queue_item is not None: - self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session) - self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace) - - for callback in self._on_non_fatal_processor_error_callbacks: - callback(exc_type, exc_value, exc_traceback, queue_item) - def start(self, invoker: Invoker) -> None: self._invoker: Invoker = invoker self._queue_item: Optional[SessionQueueItem] = None @@ -396,3 +380,21 @@ class DefaultSessionProcessor(SessionProcessorBase): poll_now_event.clear() self._queue_item = None self._thread_semaphore.release() + + def _on_non_fatal_processor_error( + self, + queue_item: Optional[SessionQueueItem], + exc_type: type, + exc_value: BaseException, + exc_traceback: TracebackType, + ) -> None: + stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback) + # Non-fatal error in processor + self._invoker.services.logger.error(f"Non-fatal error in session processor:\n{stacktrace}") + # Cancel the queue item + if queue_item is not None: + self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session) + self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace) + + for callback in self._on_non_fatal_processor_error_callbacks: + callback(exc_type, exc_value, exc_traceback, queue_item)